From 4ab14e709ee4b885fbdec1b80323a97a8d017ec8 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 3 Oct 2015 18:18:44 -0700 Subject: [PATCH] feat(library/blast): finish is_def_eq/unifier for blast tactic --- src/library/blast/expr.cpp | 7 ++ src/library/blast/expr.h | 1 + src/library/blast/infer_type.cpp | 159 ++++++++++++++++++++---- src/library/blast/state.cpp | 200 +++++++++++++++++++++++++++++++ src/library/blast/state.h | 71 +++++++++++ 5 files changed, 416 insertions(+), 22 deletions(-) diff --git a/src/library/blast/expr.cpp b/src/library/blast/expr.cpp index 416ded85a8..84d9f8288a 100644 --- a/src/library/blast/expr.cpp +++ b/src/library/blast/expr.cpp @@ -280,6 +280,13 @@ expr update_constant(expr const & e, levels const & new_levels) { return e; } +expr update_local(expr const & e, expr const & new_type) { + if (!is_eqp(mlocal_type(e), new_type)) + return blast::mk_local(mlocal_name(e), local_pp_name(e), new_type, local_info(e)); + else + return e; +} + expr update_macro(expr const & e, unsigned num, expr const * args) { if (num == macro_num_args(e)) { unsigned i = 0; diff --git a/src/library/blast/expr.h b/src/library/blast/expr.h index 98e6a7471b..256f2688d9 100644 --- a/src/library/blast/expr.h +++ b/src/library/blast/expr.h @@ -91,6 +91,7 @@ expr update_metavar(expr const & e, expr const & new_type); expr update_binding(expr const & e, expr const & new_domain, expr const & new_body); expr update_sort(expr const & e, level const & new_level); expr update_constant(expr const & e, levels const & new_levels); +expr update_local(expr const & e, expr const & new_type); expr update_macro(expr const & e, unsigned num, expr const * args); void initialize_expr(); diff --git a/src/library/blast/infer_type.cpp b/src/library/blast/infer_type.cpp index 9fbc99f0bf..ede0eea420 100644 --- a/src/library/blast/infer_type.cpp +++ b/src/library/blast/infer_type.cpp @@ -7,6 +7,7 @@ Author: Leonardo de Moura #include "util/interrupt.h" #include "kernel/instantiate.h" #include "kernel/abstract.h" +#include "kernel/for_each_fn.h" #include "library/normalize.h" #include "library/blast/infer_type.h" #include "library/blast/blast_context.h" @@ -267,8 +268,23 @@ bool is_def_eq(level const & l1, level const & l2) { if (is_equivalent(l1, l2)) { return true; } else { - // TODO(Leo): check/update universe level assignment - lean_unreachable(); + state & s = curr_state(); + if (is_uref(l1)) { + if (s.is_uref_assigned(l1)) { + return is_def_eq(*s.get_uref_assignment(l1), l2); + } else { + s.assign_uref(l1, l2); + return true; + } + } + if (is_uref(l2)) { + if (s.is_uref_assigned(l2)) { + return is_def_eq(l1, *s.get_uref_assignment(l2)); + } else { + s.assign_uref(l2, l1); + return true; + } + } return false; } } @@ -285,18 +301,105 @@ bool is_def_eq(levels const & ls1, levels const & ls2) { } } +static bool is_def_eq_core(expr const & e1, expr const & e2); + /** \brief Given \c e of the form ?m t_1 ... t_n, where ?m is an assigned mref, substitute \c ?m with its assignment. */ static expr subst_mref(expr const & e) { - // TODO(Leo): - lean_unreachable(); + buffer args; + expr const & u = get_app_args(e, args); + expr const * v = curr_state().get_mref_assignment(u); + lean_assert(v); + return apply_beta(*v, args.size(), args.data()); } -/** \brief Given \c m of the form ?m t_1 ... t_n, (try to) assign +/** \brief Given \c ma of the form ?m t_1 ... t_n, (try to) assign ?m to (an abstraction of) v. Return true if success and false otherwise. */ -static bool assign_mref(expr const & m, expr const & v) { - // TODO(Leo): - lean_unreachable(); +static bool assign_mref_core(expr const & ma, expr const & v) { + buffer args; + expr const & m = get_app_args(ma, args); + buffer locals; + for (expr const & arg : args) { + if (!blast::is_local(arg)) + break; // is not local + if (std::any_of(locals.begin(), locals.end(), [&](expr const & local) { return local_index(local) == local_index(arg); })) + break; // duplicate local + locals.push_back(arg); + } + lean_assert(is_mref(m)); + state & s = curr_state(); + metavar_decl const * d = s.get_metavar_decl(m); + lean_assert(d); + expr new_v = s.instantiate_urefs_mrefs(v); + // We must check + // 1. All href in new_v are in the context of m. + // 2. The context of any (unassigned) mref in new_v must be a subset of the context of m. + // If it is not we force it to be. + // 3. Any local constant occurring in new_v occurs in locals + // 4. m does not occur in new_v + bool ok = true; + for_each(v, [&](expr const & e, unsigned) { + if (!ok) + return false; // stop search + if (is_href(e)) { + if (!d->contains_href(e)) { + ok = false; // failed 1 + return false; + } + } else if (blast::is_local(e)) { + if (std::all_of(locals.begin(), locals.end(), [&](expr const & a) { return local_index(a) != local_index(e); })) { + ok = false; // failed 3 + return false; + } + } else if (is_mref(e)) { + if (m == e) { + ok = false; // failed 4 + return false; + } + s.restrict_mref_context_using(e, m); // enforce 2 + return false; + } + return true; + }); + if (!ok) + return false; + if (args.empty()) { + // easy case + s.assign_mref(m, new_v); + return true; + } else if (args.size() == locals.size()) { + s.assign_mref(m, Fun(locals, v)); + return true; + } else { + // This case is imprecise since it is not a higher order pattern. + // That the term \c ma is of the form (?m t_1 ... t_n) and the t_i's are not pairwise + // distinct local constants. + expr m_type = d->get_type(); + for (unsigned i = 0; i < args.size(); i++) { + m_type = whnf(m_type); + if (!is_pi(m_type)) + return false; + lean_assert(i <= locals.size()); + if (i == locals.size()) + locals.push_back(blast::mk_local(mk_fresh_local_name(), binding_name(m_type), binding_domain(m_type), binding_info(m_type))); + lean_assert(i < locals.size()); + m_type = instantiate(binding_body(m_type), locals[i]); + } + lean_assert(locals.size() == args.size()); + s.assign_mref(m, Fun(locals, v)); + return true; + } +} + +/** \brief Given \c ma of the form ?m t_1 ... t_n, (try to) assign + ?m to (an abstraction of) v. Return true if success and false otherwise. */ +static bool assign_mref(expr const & ma, expr const & v) { + if (assign_mref_core(ma, v)) + return true; + expr const & f = get_app_fn(v); + if (is_mref(f) && curr_state().is_mref_assigned(f)) + return assign_mref_core(v, ma); + return false; } static bool is_def_eq_binding(expr e1, expr e2) { @@ -309,7 +412,7 @@ static bool is_def_eq_binding(expr e1, expr e2) { if (binding_domain(e1) != binding_domain(e2)) { var_e1_type = instantiate_rev(binding_domain(e1), subst.size(), subst.data()); expr var_e2_type = instantiate_rev(binding_domain(e2), subst.size(), subst.data()); - if (!is_def_eq(var_e2_type, *var_e1_type)) + if (!is_def_eq_core(var_e2_type, *var_e1_type)) return false; } if (!closed(binding_body(e1)) || !closed(binding_body(e2))) { @@ -325,8 +428,8 @@ static bool is_def_eq_binding(expr e1, expr e2) { e1 = binding_body(e1); e2 = binding_body(e2); } while (e1.kind() == k && e2.kind() == k); - return is_def_eq(instantiate_rev(e1, subst.size(), subst.data()), - instantiate_rev(e2, subst.size(), subst.data())); + return is_def_eq_core(instantiate_rev(e1, subst.size(), subst.data()), + instantiate_rev(e2, subst.size(), subst.data())); } static bool is_def_eq_app(expr const & e1, expr const & e2) { @@ -334,10 +437,10 @@ static bool is_def_eq_app(expr const & e1, expr const & e2) { buffer args1, args2; expr const & f1 = get_app_args(e1, args1); expr const & f2 = get_app_args(e2, args2); - if (args1.size() != args2.size() || !is_def_eq(f1, f2)) + if (args1.size() != args2.size() || !is_def_eq_core(f1, f2)) return false; for (unsigned i = 0; i < args1.size(); i++) { - if (!is_def_eq(args1[i], args2[i])) + if (!is_def_eq_core(args1[i], args2[i])) return false; } return true; @@ -347,7 +450,7 @@ static bool is_def_eq_eta(expr const & e1, expr const & e2) { expr new_e1 = try_eta(e1); expr new_e2 = try_eta(e2); if (e1 != new_e1 || e2 != new_e2) - return is_def_eq(new_e1, new_e2); + return is_def_eq_core(new_e1, new_e2); return false; } @@ -356,26 +459,25 @@ static bool is_def_eq_proof_irrel(expr const & e1, expr const & e2) { return false; expr e1_type = infer_type(e1); expr e2_type = infer_type(e2); - return is_prop(e1_type) && is_def_eq(e1_type, e2_type); + return is_prop(e1_type) && is_def_eq_core(e1_type, e2_type); } -bool is_def_eq(expr const & e1, expr const & e2) { +static bool is_def_eq_core(expr const & e1, expr const & e2) { check_system("is_def_eq"); if (e1 == e2) return true; - state & s = curr_state(); expr const & f1 = get_app_fn(e1); if (is_mref(f1)) { - if (s.is_mref_assigned(f1)) { - return is_def_eq(subst_mref(e1), e2); + if (curr_state().is_mref_assigned(f1)) { + return is_def_eq_core(subst_mref(e1), e2); } else { return assign_mref(e1, e2); } } expr const & f2 = get_app_fn(e2); if (is_mref(f2)) { - if (s.is_mref_assigned(f2)) { - return is_def_eq(e1, subst_mref(e2)); + if (curr_state().is_mref_assigned(f2)) { + return is_def_eq_core(e1, subst_mref(e2)); } else { return assign_mref(e2, e1); } @@ -383,7 +485,7 @@ bool is_def_eq(expr const & e1, expr const & e2) { expr e1_n = whnf(e1); expr e2_n = whnf(e2); if (e1 != e1_n || e2 != e2_n) - return is_def_eq(e1_n, e2_n); + return is_def_eq_core(e1_n, e2_n); if (e1.kind() == e2.kind()) { switch (e1.kind()) { case expr_kind::Lambda: @@ -416,4 +518,17 @@ bool is_def_eq(expr const & e1, expr const & e2) { return true; return is_def_eq_proof_irrel(e1, e2); } + +/** \remark Precision of is_def_eq can be improved if mrefs and urefs in e1 and e2 are instantiated + before we invoke is_def_eq */ +bool is_def_eq(expr const & e1, expr const & e2) { + if (e1 == e2) + return true; // quick check + state & s = curr_state(); + state::assignment_snapshot saved(s); + bool r = is_def_eq_core(e1, e2); + if (!r) + saved.restore(); + return r; +} }} diff --git a/src/library/blast/state.cpp b/src/library/blast/state.cpp index e9d8c8ca6f..879d4583f2 100644 --- a/src/library/blast/state.cpp +++ b/src/library/blast/state.cpp @@ -4,13 +4,31 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ +#include "kernel/instantiate.h" #include "kernel/abstract.h" #include "kernel/for_each_fn.h" #include "kernel/replace_fn.h" +#include "library/replace_visitor.h" #include "library/blast/state.h" namespace lean { namespace blast { +bool metavar_decl::restrict_context_using(metavar_decl const & other) { + buffer new_ctx; + bool modified = false; + for (unsigned href : m_context) { + if (other.m_context_as_set.contains(href)) { + new_ctx.push_back(href); + } else { + modified = true; + m_context_as_set.erase(href); + } + } + if (modified) + m_context = to_list(new_ctx); + return modified; +} + state::state():m_next_uref_index(0), m_next_mref_index(0) {} /** \brief Mark that hypothesis h with index hidx is fixed by the meta-variable midx. @@ -61,6 +79,16 @@ expr state::mk_metavar(expr const & type) { return state::mk_metavar(ctx, type); } +void state::restrict_mref_context_using(expr const & mref1, expr const & mref2) { + metavar_decl const * d1 = m_metavar_decls.find(mref_index(mref1)); + metavar_decl const * d2 = m_metavar_decls.find(mref_index(mref2)); + lean_assert(d1); + lean_assert(d2); + metavar_decl new_d1(*d1); + if (new_d1.restrict_context_using(*d2)) + m_metavar_decls.insert(mref_index(mref1), new_d1); +} + goal state::to_goal(branch const & b) const { hypothesis_idx_map hidx2local; metavar_idx_map midx2meta; @@ -122,6 +150,178 @@ void state::display(environment const & env, io_state const & ios) const { ios.get_diagnostic_channel() << mk_pair(to_goal().pp(fmt), ios.get_options()); } +bool state::has_assigned_uref(level const & l) const { + if (!has_meta(l)) + return false; + if (m_uassignment.empty()) + return false; + bool found = false; + for_each(l, [&](level const & l) { + if (!has_meta(l)) + return false; // stop search + if (found) + return false; // stop search + if (is_uref(l) && is_uref_assigned(l)) { + found = true; + return false; // stop search + } + return true; // continue search + }); + return found; +} + +bool state::has_assigned_uref(levels const & ls) const { + for (level const & l : ls) { + if (has_assigned_uref(l)) + return true; + } + return false; +} + +bool state::has_assigned_uref_mref(expr const & e) const { + if (!has_mref(e) && !has_univ_metavar(e)) + return false; + if (m_eassignment.empty() && m_uassignment.empty()) + return false; + bool found = false; + for_each(e, [&](expr const & e, unsigned) { + if (!has_mref(e) && !has_univ_metavar(e)) + return false; // stop search + if (found) + return false; // stop search + if ((is_mref(e) && is_mref_assigned(e)) || + (is_constant(e) && has_assigned_uref(const_levels(e))) || + (is_sort(e) && has_assigned_uref(sort_level(e)))) { + found = true; + return false; // stop search + } + return true; // continue search + }); + return found; +} + +level state::instantiate_urefs(level const & l) { + if (!has_assigned_uref(l)) + return l; + return replace(l, [&](level const & l) { + if (!has_meta(l)) { + return some_level(l); + } else if (is_uref(l)) { + level const * v1 = get_uref_assignment(l); + if (v1) { + level v2 = instantiate_urefs(*v1); + if (*v1 != v2) { + assign_uref(l, v2); + return some_level(v2); + } else { + return some_level(*v1); + } + } + } + return none_level(); + }); +} + +struct instantiate_urefs_mrefs_fn : public replace_visitor { + state & m_state; + + level visit_level(level const & l) { + return m_state.instantiate_urefs(l); + } + + levels visit_levels(levels const & ls) { + return map_reuse(ls, + [&](level const & l) { return visit_level(l); }, + [](level const & l1, level const & l2) { return is_eqp(l1, l2); }); + } + + virtual expr visit_sort(expr const & s) { + return blast::update_sort(s, visit_level(sort_level(s))); + } + + virtual expr visit_constant(expr const & c) { + return blast::update_constant(c, visit_levels(const_levels(c))); + } + + virtual expr visit_local(expr const & e) { + if (blast::is_local(e)) { + return blast::update_local(e, visit(mlocal_type(e))); + } else { + lean_assert(is_href(e)); + return e; + } + } + + virtual expr visit_meta(expr const & m) { + lean_assert(is_mref(m)); + if (auto v1 = m_state.get_mref_assignment(m)) { + if (!has_mref(*v1)) { + return *v1; + } else { + expr v2 = m_state.instantiate_urefs_mrefs(*v1); + if (v2 != *v1) + m_state.assign_mref(m, v2); + return v2; + } + } else { + return m; + } + } + + virtual expr visit_app(expr const & e) { + buffer args; + expr const & f = get_app_rev_args(e, args); + if (is_mref(f)) { + if (auto v = m_state.get_mref_assignment(f)) { + expr new_app = apply_beta(*v, args.size(), args.data()); + if (has_mref(new_app)) + return visit(new_app); + else + return new_app; + } + } + expr new_f = visit(f); + buffer new_args; + bool modified = !is_eqp(new_f, f); + for (expr const & arg : args) { + expr new_arg = visit(arg); + if (!is_eqp(arg, new_arg)) + modified = true; + new_args.push_back(new_arg); + } + if (!modified) + return e; + else + return mk_rev_app(new_f, new_args, e.get_tag()); + } + + virtual expr visit_macro(expr const & e) { + lean_assert(is_macro(e)); + buffer new_args; + for (unsigned i = 0; i < macro_num_args(e); i++) + new_args.push_back(visit(macro_arg(e, i))); + return blast::update_macro(e, new_args.size(), new_args.data()); + } + + virtual expr visit(expr const & e) { + if (!has_mref(e) || !has_univ_metavar(e)) + return e; + else + return replace_visitor::visit(e); + } + +public: + instantiate_urefs_mrefs_fn(state & s):m_state(s) {} + expr operator()(expr const & e) { return visit(e); } +}; + +expr state::instantiate_urefs_mrefs(expr const & e) { + if (!has_assigned_uref_mref(e)) + return e; + else + return instantiate_urefs_mrefs_fn(*this)(e); +} + #ifdef LEAN_DEBUG bool state::check_hypothesis(expr const & e, branch const & b, unsigned hidx, hypothesis const & h) const { lean_assert(closed(e)); diff --git a/src/library/blast/state.h b/src/library/blast/state.h index f00f6f0e40..afa731aab5 100644 --- a/src/library/blast/state.h +++ b/src/library/blast/state.h @@ -22,7 +22,14 @@ public: metavar_decl(hypothesis_idx_list const & c, hypothesis_idx_set const & s, expr const & t): m_context(c), m_context_as_set(s), m_type(t) {} hypothesis_idx_list get_context() const { return m_context; } + /** \brief Return true iff \c h is in the context of the this metavar declaration */ + bool contains_href(expr const & h) const { + return m_context_as_set.contains(href_index(h)); + } expr const & get_type() const { return m_type; } + /** \brief Make sure the declaration context of this declaration is a subset of \c other. + \remark Return true iff the context has been modified. */ + bool restrict_context_using(metavar_decl const & other); }; class state { @@ -59,8 +66,29 @@ class state { #endif public: state(); + level mk_uref(); + bool is_uref_assigned(level const & l) const { + lean_assert(is_uref(l)); + return m_uassignment.contains(uref_index(l)); + } + + // u := l + void assign_uref(level const & u, level const & l) { + lean_assert(!is_uref_assigned(u)); + m_uassignment.insert(uref_index(u), l); + } + + level const * get_uref_assignment(level const & l) const { + lean_assert(is_uref_assigned(l)); + return m_uassignment.find(uref_index(l)); + } + + /** \brief Instantiate any assigned uref in \c l with its assignment. + \remark This is not a const method because it may normalize the assignment. */ + level instantiate_urefs(level const & l); + /** \brief Create a new metavariable using the given type and context. \pre ctx must be a subset of the hypotheses in the main branch. */ expr mk_metavar(hypothesis_idx_buffer const & ctx, expr const & type); @@ -68,11 +96,37 @@ public: The context of this metavariable will be all hypotheses occurring in the main branch. */ expr mk_metavar(expr const & type); + /** \brief Make sure the metavariable declaration context of mref1 is a + subset of the metavariable declaration context of mref2. */ + void restrict_mref_context_using(expr const & mref1, expr const & mref2); + bool is_mref_assigned(expr const & e) const { lean_assert(is_mref(e)); return m_eassignment.contains(mref_index(e)); } + /** \brief Return true iff \c l contains an assigned uref */ + bool has_assigned_uref(level const & l) const; + bool has_assigned_uref(levels const & ls) const; + + expr const * get_mref_assignment(expr const & e) const { + lean_assert(is_mref(e)); + return m_eassignment.find(mref_index(e)); + } + + // m := e + void assign_mref(expr const & m, expr const & e) { + lean_assert(!is_mref_assigned(m)); + m_eassignment.insert(mref_index(m), e); + } + + /** \brief Return true if \c e contains an assigned mref or uref */ + bool has_assigned_uref_mref(expr const & e) const; + + /** \brief Instantiate any assigned mref in \c e with its assignment. + \remark This is not a const method because it may normalize the assignment. */ + expr instantiate_urefs_mrefs(expr const & e); + /** \brief Add a new hypothesis to the main branch */ expr add_hypothesis(name const & n, expr const & type, optional const & value, optional const & jst) { return m_main.add_hypothesis(n, type, value, jst); @@ -102,6 +156,23 @@ public: void display(environment const & env, io_state const & ios) const; + /** Auxiliary object for creating snapshots of the metavariable assignments. + \remark The snapshots are created (and restored) in constant time */ + class assignment_snapshot { + state & m_state; + uassignment m_old_uassignment; + eassignment m_old_eassignment; + public: + assignment_snapshot(state & s): + m_state(s), + m_old_uassignment(s.m_uassignment), + m_old_eassignment(s.m_eassignment) {} + void restore() { + m_state.m_uassignment = m_old_uassignment; + m_state.m_eassignment = m_old_eassignment; + } + }; + #ifdef LEAN_DEBUG bool check_invariant() const; #endif