refactor(frontends/lean/elaborator): snapshots

This commit is contained in:
Leonardo de Moura 2016-07-29 10:36:14 -07:00
parent aae33e02b0
commit d05e5422f9
5 changed files with 103 additions and 67 deletions

View file

@ -85,13 +85,13 @@ static std::string pos_string_for(expr const & e) {
level elaborator::mk_univ_metavar() {
level r = m_ctx.mk_univ_metavar_decl();
m_uvar_stack.push_back(r);
m_uvar_stack = cons(r, m_uvar_stack);
return r;
}
expr elaborator::mk_metavar(expr const & A) {
expr r = copy_tag(A, m_ctx.mk_metavar_decl(m_ctx.lctx(), A));
m_mvar_stack.push_back(r);
m_mvar_stack = cons(r, m_mvar_stack);
return r;
}
@ -116,7 +116,7 @@ expr elaborator::mk_instance_core(expr const & C) {
expr elaborator::mk_instance(expr const & C) {
if (has_expr_metavar(C)) {
expr inst = mk_metavar(C);
m_instance_stack.push_back(inst);
m_instance_stack = cons(inst, m_instance_stack);
return inst;
} else {
return mk_instance_core(C);
@ -345,10 +345,10 @@ expr elaborator::visit_prenum(expr const & e, optional<expr> const & expected_ty
if (expected_type) {
A = *expected_type;
if (is_metavar(*expected_type))
m_numeral_type_stack.push_back(A);
m_numeral_type_stack = cons(A, m_numeral_type_stack);
} else {
A = mk_type_metavar();
m_numeral_type_stack.push_back(A);
m_numeral_type_stack = cons(A, m_numeral_type_stack);
}
level A_lvl = get_level(A);
levels ls(A_lvl);
@ -688,13 +688,11 @@ expr elaborator::visit_default_app_core(expr const & fn, arg_mask amask, buffer<
if (auto new_r = ensure_has_type(r, instantiate_mvars(type), *expected_type)) {
return *new_r;
} else {
throw elaborator_exception(ref, format("invalid application") + pp_indent(r) +
line() + format("has type") + pp_indent(type) +
line() + format("but is expected to have type") + pp_indent(*expected_type));
/* We do not generate the error here because we can produce a better one from
the caller (i.e., place the set the expected_type) */
}
} else {
return r;
}
return r;
}
expr elaborator::visit_default_app(expr const & fn, arg_mask amask, buffer<expr> const & args,
@ -711,23 +709,36 @@ expr elaborator::visit_overloaded_app(buffer<expr> const & fns, buffer<expr> con
optional<expr> const & expected_type, expr const & ref) {
trace_elab_detail(tout() << "overloaded application at " << pos_string_for(ref);
tout() << pp_overloads(fns) << "\n";);
unsigned initial_inst_stack_sz = m_instance_stack.size();
list<expr> saved_instance_stack = m_instance_stack;
buffer<expr> new_args;
for (expr const & arg : args) {
new_args.push_back(visit(arg, none_expr()));
}
metavar_context mctx = m_ctx.mctx();
buffer<pair<expr, metavar_context>> candidates;
snapshot S(*this);
buffer<pair<expr, snapshot>> candidates;
buffer<elaborator_exception> error_msgs;
for (expr const & fn : fns) {
try {
checkpoint C(*this);
m_ctx.set_mctx(mctx);
// Restore state
S.restore(*this);
expr c = visit_overload_candidate(fn, new_args, expected_type, ref);
try_to_synthesize_type_class_instances(initial_inst_stack_sz);
candidates.emplace_back(c, m_ctx.mctx());
C.commit();
try_to_synthesize_type_class_instances(saved_instance_stack);
if (expected_type) {
expr c_type = infer_type(c);
if (ensure_has_type(c, c_type, *expected_type)) {
candidates.emplace_back(c, snapshot(*this));
} else {
throw elaborator_exception(ref, format("invalid overload, expression") + pp_indent(c) +
line() + format("has type") + pp_indent(c_type) +
line() + format("but is expected to have type") + pp_indent(*expected_type));
}
} else {
candidates.emplace_back(c, snapshot(*this));
}
} catch (elaborator_exception & ex) {
error_msgs.push_back(ex);
} catch (exception & ex) {
@ -737,7 +748,9 @@ expr elaborator::visit_overloaded_app(buffer<expr> const & fns, buffer<expr> con
lean_assert(candidates.size() + error_msgs.size() == fns.size());
if (candidates.empty()) {
format r("none of the overloads is applicable");
S.restore(*this);
format r("none of the overloads are applicable");
lean_assert(error_msgs.size() == fns.size());
for (unsigned i = 0; i < fns.size(); i++) {
if (i > 0) r += line();
@ -747,6 +760,8 @@ expr elaborator::visit_overloaded_app(buffer<expr> const & fns, buffer<expr> con
}
throw elaborator_exception(ref, r);
} else if (candidates.size() > 1) {
S.restore(*this);
options new_opts = m_opts.update_if_undef(get_pp_full_names_name(), true);
flet<options> set_opts(m_opts, new_opts);
format r("ambiguous overload, possible interpretations");
@ -755,7 +770,8 @@ expr elaborator::visit_overloaded_app(buffer<expr> const & fns, buffer<expr> con
}
throw elaborator_exception(ref, r);
} else {
m_ctx.set_mctx(candidates[0].second);
// Restore successful state
candidates[0].second.restore(*this);
return candidates[0].first;
}
}
@ -917,7 +933,7 @@ expr elaborator::visit_let(expr const & e, optional<expr> const & expected_type)
lean_unreachable();
}
expr elaborator::visit_placeholder(expr const & e, optional<expr> const & expected_type) {
expr elaborator::visit_placeholder(expr const &, optional<expr> const & expected_type) {
if (expected_type)
return mk_metavar(*expected_type);
else
@ -980,23 +996,29 @@ expr elaborator::get_default_numeral_type() {
}
void elaborator::ensure_numeral_types_assigned(checkpoint const & C) {
for (unsigned i = C.m_numeral_type_stack_sz; i < m_numeral_type_stack.size(); i++) {
expr A = instantiate_mvars(m_numeral_type_stack[i]);
list<expr> old_stack = C.m_saved_numeral_type_stack;
while (!is_eqp(m_numeral_type_stack, old_stack)) {
lean_assert(m_numeral_type_stack);
expr A = instantiate_mvars(head(m_numeral_type_stack));
if (is_metavar(A)) {
if (!assign_mvar(A, get_default_numeral_type()))
throw elaborator_exception(A, format("invalid numeral, failed to force numeral to be a nat"));
}
m_numeral_type_stack = tail(m_numeral_type_stack);
}
}
void elaborator::synthesize_type_class_instances_core(unsigned old_sz, bool force) {
unsigned j = old_sz;
for (unsigned i = old_sz; i < m_instance_stack.size(); i++) {
lean_assert(is_metavar(m_instance_stack[i]));
metavar_decl mdecl = *m_ctx.mctx().get_metavar_decl(m_instance_stack[i]);
expr inst = instantiate_mvars(m_instance_stack[i]);
void elaborator::synthesize_type_class_instances_core(list<expr> const & old_stack, bool force) {
buffer<expr> to_keep;
while (!is_eqp(m_instance_stack, old_stack)) {
lean_assert(m_instance_stack);
lean_assert(is_metavar(head(m_instance_stack)));
expr mvar = head(m_instance_stack);
metavar_decl mdecl = *m_ctx.mctx().get_metavar_decl(mvar);
expr inst = instantiate_mvars(mvar);
m_instance_stack = tail(m_instance_stack);
if (!has_expr_metavar(inst)) {
trace_elab(tout() << "skipping type class resolution at " << pos_string_for(m_instance_stack[i])
trace_elab(tout() << "skipping type class resolution at " << pos_string_for(mvar)
<< ", placeholder instantiated using type inference\n";);
continue;
}
@ -1004,21 +1026,20 @@ void elaborator::synthesize_type_class_instances_core(unsigned old_sz, bool forc
if (!has_expr_metavar(inst_type)) {
// We must try to synthesize instance using the local context where it was declared
if (!is_def_eq(inst, mk_instance_core(mdecl.get_context(), inst_type)))
throw elaborator_exception(m_instance_stack[i],
throw elaborator_exception(mvar,
format("failed to assign type class instance to placeholder"));
} else {
if (force) {
throw elaborator_exception(m_instance_stack[i],
throw elaborator_exception(mvar,
format("type class instance cannot be synthesized, type has metavariables") +
pp_indent(inst_type));
} else {
m_instance_stack[j] = m_instance_stack[i];
j++;
to_keep.push_back(mvar);
}
}
}
if (!force)
m_instance_stack.shrink(j);
for (expr const & mvar : to_keep)
m_instance_stack = cons(mvar, m_instance_stack);
}
void elaborator::invoke_tactics(checkpoint const & C) {
@ -1026,30 +1047,40 @@ void elaborator::invoke_tactics(checkpoint const & C) {
}
void elaborator::ensure_no_unassigned_metavars(checkpoint const & C) {
// TODO(Leo)
}
void elaborator::process_checkpoint(checkpoint const & C) {
void elaborator::process_checkpoint(checkpoint & C) {
ensure_numeral_types_assigned(C);
synthesize_type_class_instances(C);
invoke_tactics(C);
ensure_no_unassigned_metavars(C);
C.commit();
}
elaborator::snapshot::snapshot(elaborator & e):
m_saved_mctx(e.m_ctx.mctx()),
m_saved_uvar_stack(e.m_uvar_stack),
m_saved_mvar_stack(e.m_mvar_stack),
m_saved_instance_stack(e.m_instance_stack),
m_saved_numeral_type_stack(e.m_numeral_type_stack) {}
void elaborator::snapshot::restore(elaborator & e) {
e.m_ctx.set_mctx(m_saved_mctx);
e.m_uvar_stack = m_saved_uvar_stack;
e.m_mvar_stack = m_saved_mvar_stack;
e.m_instance_stack = m_saved_instance_stack;
e.m_numeral_type_stack = m_saved_numeral_type_stack;
}
elaborator::checkpoint::checkpoint(elaborator & e):
snapshot(e),
m_elaborator(e),
m_commit(false),
m_uvar_stack_sz(e.m_uvar_stack.size()),
m_mvar_stack_sz(e.m_mvar_stack.size()),
m_instance_stack_sz(e.m_instance_stack.size()),
m_numeral_type_stack_sz(e.m_numeral_type_stack.size()) {
}
m_commit(false) {}
elaborator::checkpoint::~checkpoint() {
if (!m_commit) {
m_elaborator.m_uvar_stack.shrink(m_uvar_stack_sz);
m_elaborator.m_mvar_stack.shrink(m_mvar_stack_sz);
m_elaborator.m_instance_stack.shrink(m_instance_stack_sz);
m_elaborator.m_numeral_type_stack.shrink(m_numeral_type_stack_sz);
restore(m_elaborator);
}
}

View file

@ -24,18 +24,24 @@ class elaborator {
local_level_decls m_local_level_decls;
type_context m_ctx;
buffer<level> m_uvar_stack;
buffer<expr> m_mvar_stack;
buffer<expr> m_instance_stack;
buffer<expr> m_numeral_type_stack;
list<level> m_uvar_stack;
list<expr> m_mvar_stack;
list<expr> m_instance_stack;
list<expr> m_numeral_type_stack;
struct checkpoint {
struct snapshot {
metavar_context m_saved_mctx;
list<level> m_saved_uvar_stack;
list<expr> m_saved_mvar_stack;
list<expr> m_saved_instance_stack;
list<expr> m_saved_numeral_type_stack;
snapshot(elaborator & elab);
void restore(elaborator & elab);
};
struct checkpoint : public snapshot {
elaborator & m_elaborator;
bool m_commit;
unsigned m_uvar_stack_sz;
unsigned m_mvar_stack_sz;
unsigned m_instance_stack_sz;
unsigned m_numeral_type_stack_sz;
checkpoint(elaborator & e);
~checkpoint();
void commit();
@ -84,7 +90,6 @@ class elaborator {
expr instantiate_mvars(expr const & e);
bool is_uvar_assigned(level const & l) const { return m_ctx.is_assigned(l); }
bool is_mvar_assigned(expr const & e) const { return m_ctx.is_assigned(e); }
void resolve_instances_from(unsigned old_sz);
level mk_univ_metavar();
expr mk_metavar(expr const & A);
@ -143,16 +148,16 @@ class elaborator {
expr visit(expr const & e, optional<expr> const & expected_type);
void ensure_numeral_types_assigned(checkpoint const & C);
void synthesize_type_class_instances_core(unsigned old_sz, bool force);
void try_to_synthesize_type_class_instances(unsigned old_sz) {
synthesize_type_class_instances_core(old_sz, false);
void synthesize_type_class_instances_core(list<expr> const & old_stack, bool force);
void try_to_synthesize_type_class_instances(list<expr> const & old_stack) {
synthesize_type_class_instances_core(old_stack, false);
}
void synthesize_type_class_instances(checkpoint const & C) {
synthesize_type_class_instances_core(C.m_instance_stack_sz, true);
synthesize_type_class_instances_core(C.m_saved_instance_stack, true);
}
void invoke_tactics(checkpoint const & C);
void ensure_no_unassigned_metavars(checkpoint const & C);
void process_checkpoint(checkpoint const & C);
void process_checkpoint(checkpoint & C);
public:
elaborator(environment const & env, options const & opts, local_level_decls const & lls,

View file

@ -3,7 +3,7 @@ elab11.lean:6:6: error: ambiguous overload, possible interpretations
boo.f 1
bla.f 1 :
boo.f 1 : bool
elab11.lean:16:7: error: none of the overloads is applicable
elab11.lean:16:7: error: none of the overloads are applicable
error for bla.f
invalid overload, expression
f 1

View file

@ -1,8 +1,8 @@
definition foo {A B : Type} [has_add A] (a : A) (b : B) : A :=
a
set_option trace.elaborator true
set_option trace.elaborator_detail true
-- set_option trace.elaborator true
-- set_option trace.elaborator_detail true
set_option pp.all true
#elab foo 0 1

View file

@ -1,8 +1,8 @@
boo.f 0 1 2 :
elab4.lean:13:6: error: none of the overloads is applicable
elab4.lean:13:6: error: none of the overloads are applicable
error for bla.f
failed to synthesize type class instance for
⊢ has_zero bool
⊢ has_add bool
error for foo.f
invalid function application, too many arguments, function type: