From d05e5422f9a5fba7b75dc4ff78bf0865aae4ad5b Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 29 Jul 2016 10:36:14 -0700 Subject: [PATCH] refactor(frontends/lean/elaborator): snapshots --- src/frontends/lean/elaborator.cpp | 125 +++++++++++++++++----------- src/frontends/lean/elaborator.h | 35 ++++---- tests/lean/elab11.lean.expected.out | 2 +- tests/lean/elab2.lean | 4 +- tests/lean/elab4.lean.expected.out | 4 +- 5 files changed, 103 insertions(+), 67 deletions(-) diff --git a/src/frontends/lean/elaborator.cpp b/src/frontends/lean/elaborator.cpp index 5eb58ceb46..07091e6004 100644 --- a/src/frontends/lean/elaborator.cpp +++ b/src/frontends/lean/elaborator.cpp @@ -85,13 +85,13 @@ static std::string pos_string_for(expr const & e) { level elaborator::mk_univ_metavar() { level r = m_ctx.mk_univ_metavar_decl(); - m_uvar_stack.push_back(r); + m_uvar_stack = cons(r, m_uvar_stack); return r; } expr elaborator::mk_metavar(expr const & A) { expr r = copy_tag(A, m_ctx.mk_metavar_decl(m_ctx.lctx(), A)); - m_mvar_stack.push_back(r); + m_mvar_stack = cons(r, m_mvar_stack); return r; } @@ -116,7 +116,7 @@ expr elaborator::mk_instance_core(expr const & C) { expr elaborator::mk_instance(expr const & C) { if (has_expr_metavar(C)) { expr inst = mk_metavar(C); - m_instance_stack.push_back(inst); + m_instance_stack = cons(inst, m_instance_stack); return inst; } else { return mk_instance_core(C); @@ -345,10 +345,10 @@ expr elaborator::visit_prenum(expr const & e, optional const & expected_ty if (expected_type) { A = *expected_type; if (is_metavar(*expected_type)) - m_numeral_type_stack.push_back(A); + m_numeral_type_stack = cons(A, m_numeral_type_stack); } else { A = mk_type_metavar(); - m_numeral_type_stack.push_back(A); + m_numeral_type_stack = cons(A, m_numeral_type_stack); } level A_lvl = get_level(A); levels ls(A_lvl); @@ -688,13 +688,11 @@ expr elaborator::visit_default_app_core(expr const & fn, arg_mask amask, buffer< if (auto new_r = ensure_has_type(r, instantiate_mvars(type), *expected_type)) { return *new_r; } else { - throw elaborator_exception(ref, format("invalid application") + pp_indent(r) + - line() + format("has type") + pp_indent(type) + - line() + format("but is expected to have type") + pp_indent(*expected_type)); + /* We do not generate the error here because we can produce a better one from + the caller (i.e., place the set the expected_type) */ } - } else { - return r; } + return r; } expr elaborator::visit_default_app(expr const & fn, arg_mask amask, buffer const & args, @@ -711,23 +709,36 @@ expr elaborator::visit_overloaded_app(buffer const & fns, buffer con optional const & expected_type, expr const & ref) { trace_elab_detail(tout() << "overloaded application at " << pos_string_for(ref); tout() << pp_overloads(fns) << "\n";); - unsigned initial_inst_stack_sz = m_instance_stack.size(); + list saved_instance_stack = m_instance_stack; buffer new_args; for (expr const & arg : args) { new_args.push_back(visit(arg, none_expr())); } - metavar_context mctx = m_ctx.mctx(); - buffer> candidates; + snapshot S(*this); + + buffer> candidates; buffer error_msgs; for (expr const & fn : fns) { try { - checkpoint C(*this); - m_ctx.set_mctx(mctx); + // Restore state + S.restore(*this); + expr c = visit_overload_candidate(fn, new_args, expected_type, ref); - try_to_synthesize_type_class_instances(initial_inst_stack_sz); - candidates.emplace_back(c, m_ctx.mctx()); - C.commit(); + try_to_synthesize_type_class_instances(saved_instance_stack); + + if (expected_type) { + expr c_type = infer_type(c); + if (ensure_has_type(c, c_type, *expected_type)) { + candidates.emplace_back(c, snapshot(*this)); + } else { + throw elaborator_exception(ref, format("invalid overload, expression") + pp_indent(c) + + line() + format("has type") + pp_indent(c_type) + + line() + format("but is expected to have type") + pp_indent(*expected_type)); + } + } else { + candidates.emplace_back(c, snapshot(*this)); + } } catch (elaborator_exception & ex) { error_msgs.push_back(ex); } catch (exception & ex) { @@ -737,7 +748,9 @@ expr elaborator::visit_overloaded_app(buffer const & fns, buffer con lean_assert(candidates.size() + error_msgs.size() == fns.size()); if (candidates.empty()) { - format r("none of the overloads is applicable"); + S.restore(*this); + + format r("none of the overloads are applicable"); lean_assert(error_msgs.size() == fns.size()); for (unsigned i = 0; i < fns.size(); i++) { if (i > 0) r += line(); @@ -747,6 +760,8 @@ expr elaborator::visit_overloaded_app(buffer const & fns, buffer con } throw elaborator_exception(ref, r); } else if (candidates.size() > 1) { + S.restore(*this); + options new_opts = m_opts.update_if_undef(get_pp_full_names_name(), true); flet set_opts(m_opts, new_opts); format r("ambiguous overload, possible interpretations"); @@ -755,7 +770,8 @@ expr elaborator::visit_overloaded_app(buffer const & fns, buffer con } throw elaborator_exception(ref, r); } else { - m_ctx.set_mctx(candidates[0].second); + // Restore successful state + candidates[0].second.restore(*this); return candidates[0].first; } } @@ -917,7 +933,7 @@ expr elaborator::visit_let(expr const & e, optional const & expected_type) lean_unreachable(); } -expr elaborator::visit_placeholder(expr const & e, optional const & expected_type) { +expr elaborator::visit_placeholder(expr const &, optional const & expected_type) { if (expected_type) return mk_metavar(*expected_type); else @@ -980,23 +996,29 @@ expr elaborator::get_default_numeral_type() { } void elaborator::ensure_numeral_types_assigned(checkpoint const & C) { - for (unsigned i = C.m_numeral_type_stack_sz; i < m_numeral_type_stack.size(); i++) { - expr A = instantiate_mvars(m_numeral_type_stack[i]); + list old_stack = C.m_saved_numeral_type_stack; + while (!is_eqp(m_numeral_type_stack, old_stack)) { + lean_assert(m_numeral_type_stack); + expr A = instantiate_mvars(head(m_numeral_type_stack)); if (is_metavar(A)) { if (!assign_mvar(A, get_default_numeral_type())) throw elaborator_exception(A, format("invalid numeral, failed to force numeral to be a nat")); } + m_numeral_type_stack = tail(m_numeral_type_stack); } } -void elaborator::synthesize_type_class_instances_core(unsigned old_sz, bool force) { - unsigned j = old_sz; - for (unsigned i = old_sz; i < m_instance_stack.size(); i++) { - lean_assert(is_metavar(m_instance_stack[i])); - metavar_decl mdecl = *m_ctx.mctx().get_metavar_decl(m_instance_stack[i]); - expr inst = instantiate_mvars(m_instance_stack[i]); +void elaborator::synthesize_type_class_instances_core(list const & old_stack, bool force) { + buffer to_keep; + while (!is_eqp(m_instance_stack, old_stack)) { + lean_assert(m_instance_stack); + lean_assert(is_metavar(head(m_instance_stack))); + expr mvar = head(m_instance_stack); + metavar_decl mdecl = *m_ctx.mctx().get_metavar_decl(mvar); + expr inst = instantiate_mvars(mvar); + m_instance_stack = tail(m_instance_stack); if (!has_expr_metavar(inst)) { - trace_elab(tout() << "skipping type class resolution at " << pos_string_for(m_instance_stack[i]) + trace_elab(tout() << "skipping type class resolution at " << pos_string_for(mvar) << ", placeholder instantiated using type inference\n";); continue; } @@ -1004,21 +1026,20 @@ void elaborator::synthesize_type_class_instances_core(unsigned old_sz, bool forc if (!has_expr_metavar(inst_type)) { // We must try to synthesize instance using the local context where it was declared if (!is_def_eq(inst, mk_instance_core(mdecl.get_context(), inst_type))) - throw elaborator_exception(m_instance_stack[i], + throw elaborator_exception(mvar, format("failed to assign type class instance to placeholder")); } else { if (force) { - throw elaborator_exception(m_instance_stack[i], + throw elaborator_exception(mvar, format("type class instance cannot be synthesized, type has metavariables") + pp_indent(inst_type)); } else { - m_instance_stack[j] = m_instance_stack[i]; - j++; + to_keep.push_back(mvar); } } } - if (!force) - m_instance_stack.shrink(j); + for (expr const & mvar : to_keep) + m_instance_stack = cons(mvar, m_instance_stack); } void elaborator::invoke_tactics(checkpoint const & C) { @@ -1026,30 +1047,40 @@ void elaborator::invoke_tactics(checkpoint const & C) { } void elaborator::ensure_no_unassigned_metavars(checkpoint const & C) { + // TODO(Leo) } -void elaborator::process_checkpoint(checkpoint const & C) { +void elaborator::process_checkpoint(checkpoint & C) { ensure_numeral_types_assigned(C); synthesize_type_class_instances(C); invoke_tactics(C); ensure_no_unassigned_metavars(C); + C.commit(); +} + +elaborator::snapshot::snapshot(elaborator & e): + m_saved_mctx(e.m_ctx.mctx()), + m_saved_uvar_stack(e.m_uvar_stack), + m_saved_mvar_stack(e.m_mvar_stack), + m_saved_instance_stack(e.m_instance_stack), + m_saved_numeral_type_stack(e.m_numeral_type_stack) {} + +void elaborator::snapshot::restore(elaborator & e) { + e.m_ctx.set_mctx(m_saved_mctx); + e.m_uvar_stack = m_saved_uvar_stack; + e.m_mvar_stack = m_saved_mvar_stack; + e.m_instance_stack = m_saved_instance_stack; + e.m_numeral_type_stack = m_saved_numeral_type_stack; } elaborator::checkpoint::checkpoint(elaborator & e): + snapshot(e), m_elaborator(e), - m_commit(false), - m_uvar_stack_sz(e.m_uvar_stack.size()), - m_mvar_stack_sz(e.m_mvar_stack.size()), - m_instance_stack_sz(e.m_instance_stack.size()), - m_numeral_type_stack_sz(e.m_numeral_type_stack.size()) { -} + m_commit(false) {} elaborator::checkpoint::~checkpoint() { if (!m_commit) { - m_elaborator.m_uvar_stack.shrink(m_uvar_stack_sz); - m_elaborator.m_mvar_stack.shrink(m_mvar_stack_sz); - m_elaborator.m_instance_stack.shrink(m_instance_stack_sz); - m_elaborator.m_numeral_type_stack.shrink(m_numeral_type_stack_sz); + restore(m_elaborator); } } diff --git a/src/frontends/lean/elaborator.h b/src/frontends/lean/elaborator.h index 75cefeb70e..f18620981b 100644 --- a/src/frontends/lean/elaborator.h +++ b/src/frontends/lean/elaborator.h @@ -24,18 +24,24 @@ class elaborator { local_level_decls m_local_level_decls; type_context m_ctx; - buffer m_uvar_stack; - buffer m_mvar_stack; - buffer m_instance_stack; - buffer m_numeral_type_stack; + list m_uvar_stack; + list m_mvar_stack; + list m_instance_stack; + list m_numeral_type_stack; - struct checkpoint { + struct snapshot { + metavar_context m_saved_mctx; + list m_saved_uvar_stack; + list m_saved_mvar_stack; + list m_saved_instance_stack; + list m_saved_numeral_type_stack; + snapshot(elaborator & elab); + void restore(elaborator & elab); + }; + + struct checkpoint : public snapshot { elaborator & m_elaborator; bool m_commit; - unsigned m_uvar_stack_sz; - unsigned m_mvar_stack_sz; - unsigned m_instance_stack_sz; - unsigned m_numeral_type_stack_sz; checkpoint(elaborator & e); ~checkpoint(); void commit(); @@ -84,7 +90,6 @@ class elaborator { expr instantiate_mvars(expr const & e); bool is_uvar_assigned(level const & l) const { return m_ctx.is_assigned(l); } bool is_mvar_assigned(expr const & e) const { return m_ctx.is_assigned(e); } - void resolve_instances_from(unsigned old_sz); level mk_univ_metavar(); expr mk_metavar(expr const & A); @@ -143,16 +148,16 @@ class elaborator { expr visit(expr const & e, optional const & expected_type); void ensure_numeral_types_assigned(checkpoint const & C); - void synthesize_type_class_instances_core(unsigned old_sz, bool force); - void try_to_synthesize_type_class_instances(unsigned old_sz) { - synthesize_type_class_instances_core(old_sz, false); + void synthesize_type_class_instances_core(list const & old_stack, bool force); + void try_to_synthesize_type_class_instances(list const & old_stack) { + synthesize_type_class_instances_core(old_stack, false); } void synthesize_type_class_instances(checkpoint const & C) { - synthesize_type_class_instances_core(C.m_instance_stack_sz, true); + synthesize_type_class_instances_core(C.m_saved_instance_stack, true); } void invoke_tactics(checkpoint const & C); void ensure_no_unassigned_metavars(checkpoint const & C); - void process_checkpoint(checkpoint const & C); + void process_checkpoint(checkpoint & C); public: elaborator(environment const & env, options const & opts, local_level_decls const & lls, diff --git a/tests/lean/elab11.lean.expected.out b/tests/lean/elab11.lean.expected.out index 2264f8d062..1133b4b260 100644 --- a/tests/lean/elab11.lean.expected.out +++ b/tests/lean/elab11.lean.expected.out @@ -3,7 +3,7 @@ elab11.lean:6:6: error: ambiguous overload, possible interpretations boo.f 1 bla.f 1 : ℕ boo.f 1 : bool -elab11.lean:16:7: error: none of the overloads is applicable +elab11.lean:16:7: error: none of the overloads are applicable error for bla.f invalid overload, expression f 1 diff --git a/tests/lean/elab2.lean b/tests/lean/elab2.lean index 322ba8a76a..87a151a41e 100644 --- a/tests/lean/elab2.lean +++ b/tests/lean/elab2.lean @@ -1,8 +1,8 @@ definition foo {A B : Type} [has_add A] (a : A) (b : B) : A := a -set_option trace.elaborator true -set_option trace.elaborator_detail true +-- set_option trace.elaborator true +-- set_option trace.elaborator_detail true set_option pp.all true #elab foo 0 1 diff --git a/tests/lean/elab4.lean.expected.out b/tests/lean/elab4.lean.expected.out index b62b4abfe8..59acbd807b 100644 --- a/tests/lean/elab4.lean.expected.out +++ b/tests/lean/elab4.lean.expected.out @@ -1,8 +1,8 @@ boo.f 0 1 2 : ℕ -elab4.lean:13:6: error: none of the overloads is applicable +elab4.lean:13:6: error: none of the overloads are applicable error for bla.f failed to synthesize type class instance for -⊢ has_zero bool +⊢ has_add bool error for foo.f invalid function application, too many arguments, function type: