refactor(util/task_queue): merge module_task into task and cancel by position
This commit is contained in:
parent
aa03dc03b4
commit
e448e4e129
14 changed files with 163 additions and 176 deletions
|
|
@ -573,7 +573,7 @@ static expr inline_new_defs(environment const & old_env, environment const & new
|
|||
});
|
||||
}
|
||||
|
||||
class proof_elaboration_task : public module_task<expr> {
|
||||
class proof_elaboration_task : public task<expr> {
|
||||
environment m_decl_env;
|
||||
options m_opts;
|
||||
bool m_use_info_manager;
|
||||
|
|
@ -596,23 +596,22 @@ public:
|
|||
bool is_rfl_lemma, expr const & final_type,
|
||||
metavar_context const & mctx, local_context const & lctx,
|
||||
parser_pos_provider const & prov) :
|
||||
module_task(optional<pos_info>(prov.get_some_pos()), task_kind::elab),
|
||||
m_decl_env(decl_env), m_opts(opts), m_use_info_manager(get_global_info_manager() != nullptr),
|
||||
m_params(params.begin(), params.end()), m_fn(fn), m_val(val), m_finfo(finfo),
|
||||
m_is_rfl_lemma(is_rfl_lemma), m_final_type(final_type),
|
||||
m_mctx(mctx), m_lctx(lctx), m_pos_provider(prov) {}
|
||||
|
||||
void description(std::ostream & out) const override {
|
||||
out << "proving " << local_pp_name(m_fn) << " (" << get_module() << ")";
|
||||
out << "proving " << local_pp_name(m_fn) << " (" << get_module_id() << ")";
|
||||
}
|
||||
|
||||
expr execute_core() override {
|
||||
expr execute() override {
|
||||
scoped_expr_caching disable(false); // FIXME: otherwise sigma.eq fails to elaborate
|
||||
auto tc = std::make_shared<type_context>(m_decl_env, m_opts, m_mctx, m_lctx);
|
||||
scope_trace_env scope2(m_decl_env, m_opts, *tc);
|
||||
scope_pos_info_provider scope3(m_pos_provider);
|
||||
scoped_info_manager scope4(
|
||||
m_use_info_manager ? get_scope_message_context().enable_info_manager(get_module())
|
||||
m_use_info_manager ? get_scope_message_context().enable_info_manager(get_module_id())
|
||||
: nullptr);
|
||||
|
||||
try {
|
||||
|
|
@ -635,7 +634,7 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class example_checking_task : public module_task<unit> {
|
||||
class example_checking_task : public task<unit> {
|
||||
environment m_decl_env;
|
||||
options m_opts;
|
||||
bool m_use_info_manager;
|
||||
|
|
@ -658,24 +657,25 @@ public:
|
|||
expr const & fn, expr const & val,
|
||||
metavar_context const & mctx, local_context const & lctx,
|
||||
parser_pos_provider const & prov) :
|
||||
module_task(optional<pos_info>(prov.get_some_pos()), task_kind::print),
|
||||
m_decl_env(decl_env), m_opts(opts), m_use_info_manager(get_global_info_manager() != nullptr),
|
||||
m_modifiers(modifiers),
|
||||
m_univ_params(univ_params), m_params(params.begin(), params.end()), m_fn(fn), m_val(val),
|
||||
m_mctx(mctx), m_lctx(lctx), m_pos_provider(prov) {
|
||||
}
|
||||
|
||||
task_kind get_kind() const override { return task_kind::print; }
|
||||
|
||||
void description(std::ostream & out) const override {
|
||||
out << "checking example on line " << m_pos_provider.get_some_pos().first << " (" << get_module() << ")";
|
||||
out << "checking example on line " << m_pos_provider.get_some_pos().first << " (" << get_module_id() << ")";
|
||||
}
|
||||
|
||||
unit execute_core() override {
|
||||
unit execute() override {
|
||||
scoped_expr_caching disable(false); // FIXME: otherwise sigma.eq fails to elaborate
|
||||
auto tc = std::make_shared<type_context>(m_decl_env, m_opts, m_mctx, m_lctx);
|
||||
scope_trace_env scope2(m_decl_env, m_opts, *tc);
|
||||
scope_pos_info_provider scope3(m_pos_provider);
|
||||
scoped_info_manager scope4(
|
||||
m_use_info_manager ? get_scope_message_context().enable_info_manager(get_module())
|
||||
m_use_info_manager ? get_scope_message_context().enable_info_manager(get_module_id())
|
||||
: nullptr);
|
||||
|
||||
try {
|
||||
|
|
|
|||
|
|
@ -2159,6 +2159,7 @@ bool parser::parse_commands() {
|
|||
if (m_stop_at && pos().first > m_stop_at_line) {
|
||||
throw interrupt_parser();
|
||||
}
|
||||
scoped_task_context scope_task_ctx(get_current_module(), pos());
|
||||
scope_message_context scope_msg_ctx;
|
||||
scoped_info_manager scope_infom( // TODO(gabriel): separate flag for snapshots/infos?
|
||||
m_snapshot_vector ? scope_msg_ctx.enable_info_manager(m_file_name)
|
||||
|
|
|
|||
|
|
@ -757,12 +757,12 @@ static void check_definition(environment const & env, declaration const & d, typ
|
|||
}
|
||||
}
|
||||
|
||||
class proof_checking_task : public module_task<expr> {
|
||||
class proof_checking_task : public task<expr> {
|
||||
environment m_env;
|
||||
declaration m_decl;
|
||||
public:
|
||||
proof_checking_task(environment const & env, declaration const & d) :
|
||||
module_task({}, task_kind::elab), m_env(env), m_decl(d) {
|
||||
m_env(env), m_decl(d) {
|
||||
lean_assert(d.is_theorem());
|
||||
}
|
||||
|
||||
|
|
@ -774,7 +774,7 @@ public:
|
|||
return { m_decl.get_value_task() };
|
||||
}
|
||||
|
||||
expr execute_core() override {
|
||||
expr execute() override {
|
||||
bool memoize = true;
|
||||
bool trusted_only = m_decl.is_trusted();
|
||||
type_checker checker(m_env, memoize, trusted_only);
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ void module_mgr::mark_out_of_date(module_id const & id, buffer<module_id> & to_r
|
|||
}
|
||||
}
|
||||
|
||||
class parse_lean_task : public module_task<module_info::parse_result> {
|
||||
class parse_lean_task : public task<module_info::parse_result> {
|
||||
environment m_initial_env;
|
||||
std::string m_contents;
|
||||
snapshot_vector m_snapshots;
|
||||
|
|
@ -44,13 +44,13 @@ public:
|
|||
parse_lean_task(std::string const & contents, environment const & initial_env,
|
||||
snapshot_vector const & snapshots, bool use_snapshots,
|
||||
std::vector<std::tuple<module_id, module_name, module_info>> const & deps) :
|
||||
module_task(optional<pos_info>(), task_kind::parse),
|
||||
m_initial_env(initial_env), m_contents(contents),
|
||||
m_snapshots(snapshots), m_use_snapshots(use_snapshots),
|
||||
m_deps(deps) {}
|
||||
task_kind get_kind() const override { return task_kind::parse; }
|
||||
|
||||
void description(std::ostream & out) const override {
|
||||
out << "parsing " << get_module();
|
||||
out << "parsing " << get_module_id();
|
||||
}
|
||||
|
||||
std::vector<generic_task_result> get_dependencies() override {
|
||||
|
|
@ -59,7 +59,7 @@ public:
|
|||
return deps;
|
||||
}
|
||||
|
||||
module_info::parse_result execute_core() override {
|
||||
module_info::parse_result execute() override {
|
||||
module_loader import_fn = [=] (module_id const & base, module_name const & import) {
|
||||
for (auto d : m_deps) {
|
||||
if (std::get<0>(d) == base &&
|
||||
|
|
@ -79,7 +79,7 @@ public:
|
|||
|
||||
bool use_exceptions = false;
|
||||
std::istringstream in(m_contents);
|
||||
parser p(m_initial_env, get_global_ios(), import_fn, in, get_module(),
|
||||
parser p(m_initial_env, get_global_ios(), import_fn, in, get_module_id(),
|
||||
use_exceptions,
|
||||
(m_snapshots.empty() || !m_use_snapshots) ? nullptr : &m_snapshots.back(),
|
||||
m_use_snapshots ? &m_snapshots : nullptr);
|
||||
|
|
@ -104,13 +104,12 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class olean_compilation_task : public module_task<unit> {
|
||||
class olean_compilation_task : public task<unit> {
|
||||
module_info m_mod;
|
||||
|
||||
public:
|
||||
olean_compilation_task(module_info const & mod) :
|
||||
module_task(optional<pos_info>(), task_kind::parse),
|
||||
m_mod(mod) {}
|
||||
olean_compilation_task(module_info const & mod) : m_mod(mod) {}
|
||||
task_kind get_kind() const override { return task_kind::parse; }
|
||||
|
||||
std::vector<generic_task_result> get_dependencies() override {
|
||||
if (auto res = m_mod.m_result.peek()) {
|
||||
|
|
@ -125,10 +124,10 @@ public:
|
|||
}
|
||||
|
||||
void description(std::ostream & out) const override {
|
||||
out << "saving object code for " << get_module();
|
||||
out << "saving object code for " << get_module_id();
|
||||
}
|
||||
|
||||
unit execute_core() override {
|
||||
unit execute() override {
|
||||
if (m_mod.m_source != module_src::LEAN)
|
||||
throw exception("cannot build olean from olean");
|
||||
auto res = m_mod.m_result.get();
|
||||
|
|
@ -159,7 +158,7 @@ void module_mgr::build_module(module_id const & id, bool can_use_olean, name_set
|
|||
|
||||
scope_global_ios scope_ios(m_ios);
|
||||
scoped_message_buffer scoped_msg_buf(m_msg_buf);
|
||||
scoped_module_id scoped_mod_mgr(id);
|
||||
scoped_task_context(id, {1, 0});
|
||||
message_bucket_id bucket_id { id, m_current_period };
|
||||
scope_message_context scope_msg_ctx(bucket_id);
|
||||
scope_traces_as_messages scope_trace_msgs(id, {1, 0});
|
||||
|
|
@ -204,6 +203,10 @@ void module_mgr::build_module(module_id const & id, bool can_use_olean, name_set
|
|||
res.m_ok = true;
|
||||
mod.m_result = mk_pure_task_result(res, "Loading " + olean_fn);
|
||||
|
||||
get_global_task_queue().cancel_if(
|
||||
[=] (generic_task * t) {
|
||||
return t->get_version() < m_current_period && t->get_module_id() == id;
|
||||
});
|
||||
if (auto old = m_modules[id].m_result) old.cancel();
|
||||
m_modules[id] = mod;
|
||||
} else if (src == module_src::LEAN) {
|
||||
|
|
@ -218,6 +221,8 @@ void module_mgr::build_module(module_id const & id, bool can_use_olean, name_set
|
|||
return;
|
||||
}
|
||||
}
|
||||
auto task_pos = snapshots.empty() ? pos_info {1, 0} : snapshots.back().m_pos;
|
||||
scoped_task_context scope_task_ctx2(id, task_pos);
|
||||
|
||||
scope_message_context scope_msg_ctx2(bucket_name);
|
||||
|
||||
|
|
@ -248,7 +253,9 @@ void module_mgr::build_module(module_id const & id, bool can_use_olean, name_set
|
|||
if (m_save_olean)
|
||||
mod.m_olean_task = get_global_task_queue().submit<olean_compilation_task>(mod);
|
||||
|
||||
if (auto old = m_modules[id].m_result) old.cancel();
|
||||
get_global_task_queue().cancel_if([=] (generic_task * t) {
|
||||
return t->get_version() < m_current_period && t->get_module_id() == id && t->get_pos() >= task_pos;
|
||||
});
|
||||
m_modules[id] = mod;
|
||||
} else {
|
||||
throw exception("unknown module source");
|
||||
|
|
@ -389,24 +396,4 @@ std::tuple<std::string, module_src, time_t> fs_module_vfs::load_module(module_id
|
|||
return std::make_tuple(read_file(lean_fn), module_src::LEAN, lean_mtime);
|
||||
}
|
||||
|
||||
LEAN_THREAD_PTR(module_id, g_scoped_module_id);
|
||||
scoped_module_id::scoped_module_id(module_id const & module) : m_mod(module) {
|
||||
m_old = g_scoped_module_id;
|
||||
g_scoped_module_id = &m_mod;
|
||||
}
|
||||
scoped_module_id::~scoped_module_id() {
|
||||
g_scoped_module_id = m_old;
|
||||
}
|
||||
module_id const & get_global_module_id() {
|
||||
return *g_scoped_module_id;
|
||||
}
|
||||
|
||||
void generic_module_task::set_result(generic_task_result const & self) {
|
||||
if (m_auto_cancel) {
|
||||
if (auto vmb = dynamic_cast<versioned_msg_buf *>(m_msg_buf))
|
||||
vmb->cancel_when_invalidated(m_bucket, self);
|
||||
}
|
||||
generic_task::set_result(self);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,8 +19,6 @@ Author: Gabriel Ebner
|
|||
|
||||
namespace lean {
|
||||
|
||||
typedef std::string module_id;
|
||||
|
||||
enum class module_src {
|
||||
OLEAN,
|
||||
LEAN,
|
||||
|
|
@ -116,82 +114,4 @@ public:
|
|||
io_state get_io_state() const { return m_ios; }
|
||||
};
|
||||
|
||||
module_id const & get_global_module_id();
|
||||
class scoped_module_id {
|
||||
module_id * m_old;
|
||||
module_id m_mod;
|
||||
public:
|
||||
scoped_module_id(module_id const &);
|
||||
~scoped_module_id();
|
||||
};
|
||||
|
||||
class generic_module_task : public generic_task {
|
||||
public:
|
||||
enum class task_kind { parse, elab, print };
|
||||
|
||||
private:
|
||||
template <class T> friend class module_task;
|
||||
message_buffer * m_msg_buf;
|
||||
io_state m_ios;
|
||||
module_id m_mod;
|
||||
message_bucket_id m_bucket;
|
||||
optional<pos_info> m_pos;
|
||||
bool m_auto_cancel;
|
||||
task_kind m_kind;
|
||||
|
||||
public:
|
||||
generic_module_task(optional<pos_info> const & pos, task_kind kind, bool auto_cancel) :
|
||||
m_msg_buf(&get_global_message_buffer()), m_ios(get_global_ios()),
|
||||
m_mod(get_global_module_id()),
|
||||
m_bucket(get_scope_message_context().new_sub_bucket()),
|
||||
m_pos(pos), m_auto_cancel(auto_cancel), m_kind(kind) {}
|
||||
|
||||
void set_result(generic_task_result const & self) override;
|
||||
|
||||
task_kind get_kind() const { return m_kind; }
|
||||
|
||||
module_id get_module() const { return m_mod; }
|
||||
pos_info get_pos_or_something() const { return m_pos ? *m_pos : pos_info{1, 0}; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class module_task : public task<T>, public generic_module_task {
|
||||
public:
|
||||
module_task(optional<pos_info> const & pos, task_kind kind, bool auto_cancel = true) :
|
||||
generic_module_task(pos, kind, auto_cancel) {}
|
||||
|
||||
void set_result(generic_task_result const & self) override {
|
||||
generic_module_task::set_result(self);
|
||||
}
|
||||
|
||||
virtual T execute_core() = 0;
|
||||
|
||||
T execute() final override;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
T module_task<T>::execute() {
|
||||
scoped_module_id scoped_mod_id(m_mod);
|
||||
scope_global_ios scoped_ios(m_ios);
|
||||
scoped_message_buffer scoped_msg_buf(m_msg_buf);
|
||||
scope_message_context scope_msg_ctx(m_bucket);
|
||||
if (m_auto_cancel && !m_msg_buf->is_bucket_valid(m_bucket)) {
|
||||
throw interrupted();
|
||||
}
|
||||
try {
|
||||
scope_traces_as_messages scope_traces(get_module(), get_pos_or_something());
|
||||
return execute_core();
|
||||
} catch (task_cancellation_exception) {
|
||||
throw;
|
||||
} catch (interrupted) {
|
||||
throw;
|
||||
} catch (throwable & ex) {
|
||||
environment env;
|
||||
message_builder builder(env, m_ios, get_module(), get_pos_or_something(), ERROR);
|
||||
builder.set_exception(ex);
|
||||
builder.report();
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ void versioned_msg_buf::start_bucket(message_bucket_id const & bucket) {
|
|||
auto & buf = m_buf[bucket.m_bucket];
|
||||
if (buf.m_version < bucket.m_version) {
|
||||
buf.m_version = bucket.m_version;
|
||||
buf.m_cancel_on_invalidation.reset();
|
||||
buf.m_msgs.clear();
|
||||
buf.m_infom.reset();
|
||||
}
|
||||
|
|
@ -55,16 +54,9 @@ void versioned_msg_buf::finish_bucket(message_bucket_id const & bucket, name_set
|
|||
});
|
||||
}
|
||||
|
||||
void versioned_msg_buf::cancel_bucket(name const & bucket) {
|
||||
auto & bck_buf = m_buf[bucket];
|
||||
bck_buf.m_children.for_each([&] (name const & c) { cancel_bucket(c); });
|
||||
if (auto & t = bck_buf.m_cancel_on_invalidation) { t.cancel(); t.reset(); }
|
||||
}
|
||||
|
||||
void versioned_msg_buf::erase_bucket(name const & bucket) {
|
||||
auto & bck_buf = m_buf[bucket];
|
||||
bck_buf.m_children.for_each([&] (name const & c) { erase_bucket(c); });
|
||||
if (auto & t = bck_buf.m_cancel_on_invalidation) t.cancel();
|
||||
m_buf.erase(bucket);
|
||||
}
|
||||
|
||||
|
|
@ -114,14 +106,4 @@ std::vector<info_manager> versioned_msg_buf::get_info_managers() {
|
|||
return result;
|
||||
}
|
||||
|
||||
void versioned_msg_buf::cancel_when_invalidated(message_bucket_id const & bucket, generic_task_result const & t) {
|
||||
unique_lock<mutex> lock(m_mutex);
|
||||
|
||||
auto & buf = m_buf[bucket.m_bucket];
|
||||
if (buf.m_version < bucket.m_version) {
|
||||
if (auto & t_old = buf.m_cancel_on_invalidation) t_old.cancel();
|
||||
buf.m_cancel_on_invalidation = t;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,15 +20,12 @@ class versioned_msg_buf : public message_buffer {
|
|||
std::unique_ptr<info_manager> m_infom;
|
||||
period m_version = 0;
|
||||
|
||||
generic_task_result m_cancel_on_invalidation;
|
||||
|
||||
name_set m_children;
|
||||
};
|
||||
|
||||
mutex m_mutex;
|
||||
std::unordered_map<name, msg_bucket, name_hash> m_buf;
|
||||
|
||||
void cancel_bucket(name const & bucket);
|
||||
void erase_bucket(name const & bucket);
|
||||
bool is_bucket_valid_core(message_bucket_id const & bucket);
|
||||
|
||||
|
|
@ -41,8 +38,6 @@ public:
|
|||
bool is_bucket_valid(message_bucket_id const & bucket) override;
|
||||
void report_info_manager(message_bucket_id const & bucket, info_manager const & infom) override;
|
||||
|
||||
void cancel_when_invalidated(message_bucket_id const & bucket, generic_task_result const & t);
|
||||
|
||||
std::vector<message> get_messages();
|
||||
std::vector<info_manager> get_info_managers();
|
||||
};
|
||||
|
|
|
|||
|
|
@ -45,6 +45,8 @@ server::server(unsigned num_threads, environment const & initial_env, io_state c
|
|||
m_ios.set_regular_channel(std::make_shared<stderr_channel>());
|
||||
m_ios.set_diagnostic_channel(std::make_shared<stderr_channel>());
|
||||
|
||||
scope_global_ios scoped_ios(m_ios);
|
||||
scoped_message_buffer scope_msg_buf(&m_msg_buf);
|
||||
#if defined(LEAN_MULTI_THREAD)
|
||||
if (num_threads == 0)
|
||||
m_tq = new st_task_queue;
|
||||
|
|
@ -380,23 +382,21 @@ mt_tq_prioritizer mk_interactive_prioritizer(module_id const & roi) {
|
|||
task_priority p;
|
||||
p.m_prio = DEFAULT_PRIO;
|
||||
|
||||
if (auto mod_task = dynamic_cast<generic_module_task *>(t)) {
|
||||
bool in_roi = mod_task->get_module() == roi;
|
||||
bool in_roi = t->get_module_id() == roi;
|
||||
|
||||
if (!in_roi)
|
||||
p.m_not_before = { chrono::steady_clock::now() + chrono::seconds(10) };
|
||||
if (!in_roi)
|
||||
p.m_not_before = { chrono::steady_clock::now() + chrono::seconds(10) };
|
||||
|
||||
switch (mod_task->get_kind()) {
|
||||
case generic_module_task::task_kind::parse:
|
||||
p.m_prio = in_roi ? ROI_PARSING_PRIO : PARSING_PRIO;
|
||||
break;
|
||||
case generic_module_task::task_kind::elab:
|
||||
p.m_prio = in_roi ? ROI_ELAB_PRIO : ELAB_PRIO;
|
||||
break;
|
||||
case generic_module_task::task_kind::print:
|
||||
p.m_prio = in_roi ? ROI_PRINT_PRIO : PRINT_PRIO;
|
||||
break;
|
||||
}
|
||||
switch (t->get_kind()) {
|
||||
case task_kind::parse:
|
||||
p.m_prio = in_roi ? ROI_PARSING_PRIO : PARSING_PRIO;
|
||||
break;
|
||||
case task_kind::elab:
|
||||
p.m_prio = in_roi ? ROI_ELAB_PRIO : ELAB_PRIO;
|
||||
break;
|
||||
case task_kind::print:
|
||||
p.m_prio = in_roi ? ROI_PRINT_PRIO : PRINT_PRIO;
|
||||
break;
|
||||
}
|
||||
|
||||
return p;
|
||||
|
|
|
|||
|
|
@ -46,7 +46,8 @@ mt_task_queue::mt_task_queue(unsigned num_workers) :
|
|||
}) {}
|
||||
|
||||
mt_task_queue::mt_task_queue(unsigned num_workers, mt_tq_prioritizer const & prioritizer) :
|
||||
m_required_workers(num_workers), m_prioritizer(prioritizer) {
|
||||
m_required_workers(num_workers), m_prioritizer(prioritizer),
|
||||
m_ios(get_global_ios()), m_msg_buf(&get_global_message_buffer()) {
|
||||
for (unsigned i = 0; i < num_workers; i++)
|
||||
spawn_worker();
|
||||
}
|
||||
|
|
@ -68,7 +69,10 @@ void mt_task_queue::spawn_worker() {
|
|||
this_worker->m_thread = thread([=] {
|
||||
this_worker->m_interrupt_flag = get_interrupt_flag();
|
||||
|
||||
scope_global_task_queue scope(this);
|
||||
scope_global_task_queue scope_tq(this);
|
||||
scope_global_ios scope_ios(m_ios);
|
||||
scoped_message_buffer scope_msg_buf(m_msg_buf);
|
||||
|
||||
unique_lock<mutex> lock(m_mutex);
|
||||
scoped_add<int> dec_required(m_required_workers, -1);
|
||||
while (true) {
|
||||
|
|
@ -147,6 +151,7 @@ void mt_task_queue::propagate_failure(generic_task_result const & tr) {
|
|||
|
||||
void mt_task_queue::submit(generic_task_result const & t) {
|
||||
unique_lock<mutex> lock(m_mutex);
|
||||
check_interrupted();
|
||||
t->m_task->m_prio = m_prioritizer(t->m_task);
|
||||
if (check_deps(t)) {
|
||||
if (!t->has_evaluated()) enqueue(t);
|
||||
|
|
@ -228,9 +233,28 @@ void mt_task_queue::wait(generic_task_result const & t) {
|
|||
}
|
||||
}
|
||||
|
||||
void mt_task_queue::cancel(generic_task_result const & t) {
|
||||
if (!t) return;
|
||||
void mt_task_queue::cancel_if(const std::function<bool(generic_task *)> & pred) {
|
||||
std::vector<generic_task_result> to_cancel;
|
||||
unique_lock<mutex> lock(m_mutex);
|
||||
|
||||
for (auto & w : m_workers)
|
||||
if (w->m_current_task && pred(w->m_current_task->m_task))
|
||||
to_cancel.push_back(w->m_current_task);
|
||||
|
||||
for (auto & q : m_queue)
|
||||
for (auto & t : q.second)
|
||||
if (t->m_task && pred(t->m_task))
|
||||
to_cancel.push_back(t);
|
||||
|
||||
for (auto & t : m_waiting)
|
||||
if (t->m_task && pred(t->m_task))
|
||||
to_cancel.push_back(t);
|
||||
|
||||
for (auto & t : to_cancel)
|
||||
cancel_core(t);
|
||||
}
|
||||
|
||||
void mt_task_queue::cancel_core(generic_task_result const & t) {
|
||||
switch (t->m_state.load()) {
|
||||
case task_result_state::QUEUED:
|
||||
t->m_ex = std::make_exception_ptr(task_cancellation_exception(t));
|
||||
|
|
@ -249,6 +273,11 @@ void mt_task_queue::cancel(generic_task_result const & t) {
|
|||
default: return;
|
||||
}
|
||||
}
|
||||
void mt_task_queue::cancel(generic_task_result const & t) {
|
||||
if (!t) return;
|
||||
unique_lock<mutex> lock(m_mutex);
|
||||
cancel_core(t);
|
||||
}
|
||||
|
||||
bool mt_task_queue::empty() {
|
||||
unique_lock<mutex> lock(m_mutex);
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ Author: Gabriel Ebner
|
|||
#include <map>
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
#include <library/io_state.h>
|
||||
#include "util/optional.h"
|
||||
#include "util/task_queue.h"
|
||||
|
||||
|
|
@ -42,9 +43,11 @@ class mt_task_queue : public task_queue {
|
|||
condition_variable m_wake_up_worker;
|
||||
|
||||
mt_tq_prioritizer m_prioritizer;
|
||||
|
||||
progress_cb m_progress_cb;
|
||||
|
||||
io_state m_ios;
|
||||
message_buffer * m_msg_buf;
|
||||
|
||||
generic_task_result dequeue();
|
||||
void enqueue(generic_task_result const &);
|
||||
|
||||
|
|
@ -52,6 +55,7 @@ class mt_task_queue : public task_queue {
|
|||
void propagate_failure(generic_task_result const &);
|
||||
void submit(generic_task_result const &) override;
|
||||
void bump_prio(generic_task_result const &, task_priority const &);
|
||||
void cancel_core(generic_task_result const &);
|
||||
|
||||
void reprioritize_core();
|
||||
|
||||
|
|
@ -66,6 +70,8 @@ public:
|
|||
void wait(generic_task_result const & t) override;
|
||||
void cancel(generic_task_result const & t) override;
|
||||
|
||||
void cancel_if(const std::function<bool(generic_task *)> &pred) override;
|
||||
|
||||
void set_progress_callback(progress_cb const & cb) override;
|
||||
|
||||
void reprioritize(mt_tq_prioritizer const & p);
|
||||
|
|
|
|||
|
|
@ -36,4 +36,6 @@ void st_task_queue::set_progress_callback(progress_cb const & cb) {
|
|||
m_progress_cb = cb;
|
||||
}
|
||||
|
||||
void st_task_queue::cancel_if(const std::function<bool(generic_task *)> &) {}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ public:
|
|||
void wait(generic_task_result const & t) override;
|
||||
void cancel(generic_task_result const & t) override;
|
||||
|
||||
void cancel_if(const std::function<bool(generic_task *)> &pred) override;
|
||||
|
||||
void set_progress_callback(progress_cb const &) override;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Author: Gabriel Ebner
|
||||
*/
|
||||
#include <string>
|
||||
#include <library/trace.h>
|
||||
#include <library/message_builder.h>
|
||||
#include "util/task_queue.h"
|
||||
|
||||
namespace lean {
|
||||
|
|
@ -15,7 +17,8 @@ std::string generic_task::description() const {
|
|||
return out.str();
|
||||
}
|
||||
|
||||
void generic_task::set_result(generic_task_result const &) {}
|
||||
generic_task::generic_task() : m_bucket(get_scope_message_context().new_sub_bucket()),
|
||||
m_mod(get_current_module()), m_pos(get_current_task_pos()) {}
|
||||
|
||||
generic_task_result_cell::generic_task_result_cell(generic_task * t) :
|
||||
m_rc(0), m_task(t), m_desc(t->description()) {}
|
||||
|
|
@ -30,9 +33,25 @@ void generic_task_result_cell::clear_task() {
|
|||
bool generic_task_result_cell::execute() {
|
||||
lean_assert(!has_evaluated());
|
||||
try {
|
||||
execute_and_store_result();
|
||||
scoped_task_context ctx(m_task->get_module_id(), m_task->get_task_pos());
|
||||
scope_message_context scope_msg_ctx(m_task->get_bucket());
|
||||
try {
|
||||
scope_traces_as_messages scope_traces(m_task->get_module_id(), m_task->get_pos());
|
||||
execute_and_store_result();
|
||||
} catch (task_cancellation_exception) {
|
||||
throw;
|
||||
} catch (interrupted) {
|
||||
throw;
|
||||
} catch (throwable & ex) {
|
||||
environment env;
|
||||
message_builder builder(env, get_global_ios(), m_task->get_module_id(), m_task->get_pos(), ERROR);
|
||||
builder.set_exception(ex);
|
||||
builder.report();
|
||||
throw;
|
||||
}
|
||||
return true;
|
||||
} catch (interrupted) {
|
||||
std::cerr << "interrupted: " << m_desc << std::endl;
|
||||
m_ex = std::make_exception_ptr(
|
||||
task_cancellation_exception(generic_task_result(this)));
|
||||
return false;
|
||||
|
|
@ -69,4 +88,20 @@ char const *task_cancellation_exception::what() const noexcept {
|
|||
return m_msg.c_str();
|
||||
}
|
||||
|
||||
LEAN_THREAD_PTR(module_id, g_cur_mod);
|
||||
LEAN_THREAD_PTR(pos_info, g_cur_task_pos);
|
||||
scoped_task_context::scoped_task_context(module_id const & mod, pos_info const & pos) : m_id(mod), m_pos(pos) {
|
||||
m_old_id = g_cur_mod;
|
||||
m_old_pos = g_cur_task_pos;
|
||||
g_cur_mod = &m_id;
|
||||
g_cur_task_pos = &m_pos;
|
||||
}
|
||||
scoped_task_context::~scoped_task_context() {
|
||||
g_cur_mod = m_old_id;
|
||||
g_cur_task_pos = m_old_pos;
|
||||
}
|
||||
|
||||
module_id get_current_module() { return *g_cur_mod; }
|
||||
pos_info get_current_task_pos() { return *g_cur_task_pos; }
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ Author: Gabriel Ebner
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
#include <library/message_buffer.h>
|
||||
#include "util/thread.h"
|
||||
#include "util/optional.h"
|
||||
#include "util/rc.h"
|
||||
|
|
@ -103,7 +104,24 @@ struct task_priority {
|
|||
}
|
||||
};
|
||||
|
||||
typedef std::string module_id;
|
||||
enum class task_kind { parse, elab, print };
|
||||
|
||||
module_id get_current_module();
|
||||
pos_info get_current_task_pos();
|
||||
class scoped_task_context {
|
||||
module_id * m_old_id;
|
||||
pos_info * m_old_pos;
|
||||
module_id m_id;
|
||||
pos_info m_pos;
|
||||
|
||||
public:
|
||||
scoped_task_context(module_id const & mod, pos_info const & pos);
|
||||
~scoped_task_context();
|
||||
};
|
||||
|
||||
class generic_task {
|
||||
template <class T> friend class task;
|
||||
friend class task_queue;
|
||||
friend class st_task_queue;
|
||||
friend class mt_task_queue;
|
||||
|
|
@ -112,24 +130,34 @@ class generic_task {
|
|||
std::vector<generic_task_result> m_reverse_deps;
|
||||
condition_variable m_has_finished;
|
||||
|
||||
// metadata
|
||||
message_bucket_id m_bucket;
|
||||
module_id m_mod;
|
||||
pos_info m_pos;
|
||||
|
||||
public:
|
||||
generic_task();
|
||||
virtual ~generic_task() {}
|
||||
|
||||
virtual void description(std::ostream &) const = 0;
|
||||
std::string description() const;
|
||||
virtual std::vector<generic_task_result> get_dependencies() { return {}; }
|
||||
|
||||
virtual void set_result(generic_task_result const & self);
|
||||
|
||||
virtual bool is_tiny() const { return false; }
|
||||
virtual task_kind get_kind() const { return task_kind::elab; }
|
||||
virtual pos_info get_pos() const { return get_task_pos(); }
|
||||
|
||||
message_bucket_id const & get_bucket() const { return m_bucket; }
|
||||
period get_version() const { return m_bucket.m_version; }
|
||||
module_id const & get_module_id() const { return m_mod; }
|
||||
pos_info const & get_task_pos() const { return m_pos; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class task : public generic_task {
|
||||
public:
|
||||
typedef T result;
|
||||
|
||||
virtual ~task() {}
|
||||
|
||||
virtual T execute() = 0;
|
||||
};
|
||||
|
||||
|
|
@ -211,7 +239,6 @@ public:
|
|||
task_result<typename T::result> task(
|
||||
new task_result_cell<typename T::result>(
|
||||
new T(std::forward<As>(args)...)));
|
||||
task->m_task->set_result(task);
|
||||
submit(task);
|
||||
return task;
|
||||
}
|
||||
|
|
@ -233,6 +260,7 @@ public:
|
|||
virtual void wait(generic_task_result const & t) = 0;
|
||||
|
||||
virtual void cancel(generic_task_result const & t) = 0;
|
||||
virtual void cancel_if(std::function<bool(generic_task *)> const & pred) = 0; // NOLINT
|
||||
|
||||
using progress_cb = std::function<void(generic_task *)>; // NOLINT
|
||||
// disabling lint because it this this is cast ^^^
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue