From 110b622b83f5fcaab691bd5371cf0f93f1c3c5c8 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 3 Jul 2014 20:41:42 -0700 Subject: [PATCH] =?UTF-8?q?feat(library/unifier):=20add=20support=20for=20?= =?UTF-8?q?unification=20constraints=20of=20the=20form=20"(elim=20...=20(?= =?UTF-8?q?=3Fm=20...))=20=3D=3F=3D=20t",=20where=20elim=20is=20an=20elimi?= =?UTF-8?q?nator?= Signed-off-by: Leonardo de Moura --- src/library/unifier.cpp | 138 +++++++++++++++++++++++++++++++++++----- tests/lean/run/uni.lean | 48 ++++++++++++++ 2 files changed, 171 insertions(+), 15 deletions(-) create mode 100644 tests/lean/run/uni.lean diff --git a/src/library/unifier.cpp b/src/library/unifier.cpp index d3216ae46c..de8605b094 100644 --- a/src/library/unifier.cpp +++ b/src/library/unifier.cpp @@ -16,6 +16,7 @@ Author: Leonardo de Moura #include "kernel/abstract.h" #include "kernel/instantiate.h" #include "kernel/type_checker.h" +#include "kernel/inductive/inductive.h" #include "library/occurs.h" #include "library/unifier.h" #include "library/kernel_bindings.h" @@ -479,9 +480,105 @@ struct unifier_fn { add_cnstr(c, mlvl_occs, mvar_occs, g_first_very_delayed); } + /** \brief Return true iff \c e is of the form (elim ... (?m ...)) */ + bool is_elim_meta_app(expr const & e) { + if (!is_app(e)) + return false; + expr const & f = get_app_fn(e); + if (!is_constant(f)) + return false; + auto it_name = inductive::is_elim_rule(m_env, const_name(f)); + if (!it_name) + return false; + if (!is_meta(app_arg(e))) + return false; + if (is_pi(m_tc.whnf(m_tc.infer(e)))) + return false; + return true; + } + + /** + \brief Given (elim args) =?= t, where elim is the eliminator/recursor for the inductive declaration \c decl, + and the last argument of args is of the form (?m ...), we create a case split where we try to assign (?m ...) + to the different constructors of decl. + */ + void mk_inductice_cnstrs(inductive::inductive_decl const & decl, expr const & elim, buffer & args, expr const & t, + justification const & j) { + lean_assert(is_constant(elim)); + levels elim_lvls = const_levels(elim); + unsigned elim_num_lvls = length(elim_lvls); + unsigned num_args = args.size(); + expr meta = args[num_args - 1]; // save last argument, we will update it + lean_assert(is_meta(meta)); + buffer margs; + expr const & m = get_app_args(meta, margs); + expr const & mtype = mlocal_type(m); + buffer alts; + for (auto const & intro : inductive::inductive_decl_intros(decl)) { + name const & intro_name = inductive::intro_rule_name(intro); + declaration intro_decl = m_env.get(intro_name); + levels intro_lvls; + if (length(intro_decl.get_univ_params()) == elim_num_lvls) { + intro_lvls = elim_lvls; + } else { + lean_assert(length(intro_decl.get_univ_params()) == elim_num_lvls - 1); + intro_lvls = tail(elim_lvls); + } + expr intro_fn = mk_constant(inductive::intro_rule_name(intro), intro_lvls); + expr hint = intro_fn; + expr intro_type = m_tc.whnf(inductive::intro_rule_type(intro)); + while (is_pi(intro_type)) { + hint = mk_app(hint, mk_app(mk_aux_metavar_for(mtype), margs)); + intro_type = m_tc.whnf(binding_body(intro_type)); + } + constraint c1 = mk_eq_cnstr(meta, hint, j); + args[num_args - 1] = hint; + expr reduce_elim = m_tc.whnf(mk_app(elim, args)); + constraint c2 = mk_eq_cnstr(reduce_elim, t, j); + alts.push_back(constraints({c1, c2})); + } + if (alts.empty()) { + set_conflict(j); + } else if (alts.size() == 1) { + process_constraints(alts[0], justification()); + } else { + justification a = mk_assumption_justification(m_next_assumption_idx); + add_case_split(std::unique_ptr(new ho_case_split(*this, to_list(alts.begin() + 1, alts.end())))); + process_constraints(alts[0], a); + } + } + + bool try_inductive_hint_core(expr const & t1, expr const & t2, justification const & j) { + if (!is_elim_meta_app(t1)) + return false; + buffer args; + expr const & elim = get_app_args(t1, args); + auto it_name = *inductive::is_elim_rule(m_env, const_name(elim)); + auto decls = *inductive::is_inductive_decl(m_env, it_name); + for (auto const & d : std::get<2>(decls)) { + if (inductive::inductive_decl_name(d) == it_name) { + mk_inductice_cnstrs(d, elim, args, t2, j); + return true; + } + } + lean_unreachable(); // LCOV_EXCL_LINE + } + + /** + \brief Try to solve constraint of the form (elim ... (?m ...)) =?= t, by assigning (?m ...) to the introduction rules + associated with the eliminator \c elim. + */ + bool try_inductive_hint(expr const & t1, expr const & t2, justification const & j) { + return + try_inductive_hint_core(t1, t2, j) || + try_inductive_hint_core(t2, t1, j); + } + bool is_def_eq(expr const & t1, expr const & t2, justification const & j) { if (m_tc.is_def_eq(t1, t2, j)) { return true; + } else if (try_inductive_hint(t1, t2, j)) { + return true; } else { set_conflict(j); return false; @@ -595,6 +692,12 @@ struct unifier_fn { rhs = m_tc.whnf(rhs); lhs = m_tc.whnf(lhs); + // We delay constraints where lhs or rhs are of the form (elim ... (?m ...)) + if (is_elim_meta_app(lhs) || is_elim_meta_app(rhs)) { + add_very_delayed_cnstr(c, &unassigned_lvls, &unassigned_exprs); + return true; + } + // If lhs or rhs were updated, then invoke is_def_eq again. if (lhs != cnstr_lhs_expr(c) || rhs != cnstr_rhs_expr(c)) { // some metavariables were instantiated, try is_def_eq again @@ -1101,10 +1204,14 @@ struct unifier_fn { } void consume_tc_cnstrs() { - while (auto c = m_tc.next_cnstr()) { + while (true) { if (in_conflict()) return; - process_constraint(*c); + if (auto c = m_tc.next_cnstr()) { + process_constraint(*c); + } else { + break; + } } } @@ -1123,11 +1230,16 @@ struct unifier_fn { return process_plugin_constraint(c); } + /** \brief Return true if unifier may be able to produce more solutions */ + bool more_solutions() const { + return !in_conflict() || !m_case_splits.empty(); + } + /** \brief Produce the next solution */ optional next() { - if (in_conflict()) + if (!more_solutions()) return failure(); - if (!m_case_splits.empty()) { + if (!m_first && !m_case_splits.empty()) { justification all_assumptions; for (auto const & cs : m_case_splits) all_assumptions = mk_composite1(all_assumptions, mk_assumption_justification(cs->m_assumption_idx)); @@ -1162,7 +1274,7 @@ unifier_plugin get_noop_unifier_plugin() { } lazy_list unify(std::shared_ptr u) { - if (u->in_conflict()) { + if (!u->more_solutions()) { u->failure(); // make sure exception is thrown if u->m_use_exception is true return lazy_list(); } else { @@ -1192,17 +1304,13 @@ lazy_list unify(environment const & env, expr const & lhs, expr co type_checker tc(env, new_ngen.mk_child()); expr _lhs = s.instantiate(lhs); expr _rhs = s.instantiate(rhs); - if (!tc.is_def_eq(_lhs, _rhs)) + auto u = std::make_shared(env, 0, nullptr, ngen, s, p, false, max_steps); + if (!u->is_def_eq(_lhs, _rhs, justification()) && !u->more_solutions()) return lazy_list(); - buffer cs; - while (auto c = tc.next_cnstr()) { - cs.push_back(*c); - } - if (cs.empty()) { - return lazy_list(s); - } else { - return unify(std::make_shared(env, cs.size(), cs.data(), ngen, s, p, false, max_steps)); - } + u->consume_tc_cnstrs(); + if (!u->more_solutions()) + return lazy_list(); + return unify(u); } lazy_list unify(environment const & env, expr const & lhs, expr const & rhs, name_generator const & ngen, diff --git a/tests/lean/run/uni.lean b/tests/lean/run/uni.lean new file mode 100644 index 0000000000..378842df0d --- /dev/null +++ b/tests/lean/run/uni.lean @@ -0,0 +1,48 @@ +import logic + +inductive nat : Type := +| zero : nat +| succ : nat → nat + +check @nat_rec + +(* +local env = get_env() +local nat_rec = Const("nat_rec", {1}) +local nat = Const("nat") +local n = Local("n", nat) +local C = Fun(n, Bool) +local p = Local("p", Bool) +local ff = Const("false") +local tt = Const("true") +local t = nat_rec(C, ff, Fun(n, p, tt)) +local zero = Const("zero") +local succ = Const("succ") +local one = succ(zero) +local tc = type_checker(env) +print(env:whnf(t(one))) +print(env:whnf(t(zero))) +local m = mk_metavar("m", nat) +print(env:whnf(t(m))) + +function test_unify(env, lhs, rhs, num_s) + print(tostring(lhs) .. " =?= " .. tostring(rhs) .. ", expected: " .. tostring(num_s)) + local ss = unify(env, lhs, rhs, name_generator(), substitution(), options()) + local n = 0 + for s in ss do + print("solution: ") + s:for_each_expr(function(n, v, j) + print(" " .. tostring(n) .. " := " .. tostring(v)) + end) + s:for_each_level(function(n, v, j) + print(" " .. tostring(n) .. " := " .. tostring(v)) + end) + n = n + 1 + end + if num_s ~= n then print("n: " .. n) end + assert(num_s == n) +end + +test_unify(env, t(m), tt, 1) +test_unify(env, t(m), ff, 1) +*) \ No newline at end of file