diff --git a/src/library/congr_lemma_manager.cpp b/src/library/congr_lemma_manager.cpp index 6153c01e27..8ace8b2665 100644 --- a/src/library/congr_lemma_manager.cpp +++ b/src/library/congr_lemma_manager.cpp @@ -8,6 +8,7 @@ Author: Leonardo de Moura #include "kernel/abstract.h" #include "library/util.h" #include "library/locals.h" +#include "library/constants.h" #include "library/replace_visitor.h" #include "library/congr_lemma_manager.h" @@ -35,6 +36,7 @@ class congr_lemma_manager::imp { } }; + std::unordered_map m_simp_cache; std::unordered_map m_cache; expr infer(expr const & e) { return m_ctx.infer(e); } @@ -192,6 +194,100 @@ class congr_lemma_manager::imp { } } + optional mk_congr(expr const & fn, optional const & simp_lemma, + buffer const & pinfos, buffer const & kinds) { + try { + expr fn_type1 = whnf(infer(fn)); + expr fn_type2 = fn_type1; + name e_name("e"); + buffer lhss; + buffer rhss; // it contains the right-hand-side argument + buffer> eqs; // for Eq args, it contains the equality + buffer hyps; // contains lhss + rhss + eqs + buffer simp_lemma_args; + for (unsigned i = 0; i < pinfos.size(); i++) { + if (!is_pi(fn_type1)) { + return optional(); + } + expr lhs = m_ctx.mk_tmp_local(binding_name(fn_type1), binding_domain(fn_type1)); + expr rhs; + lhss.push_back(lhs); + hyps.push_back(lhs); + simp_lemma_args.push_back(lhs); + switch (kinds[i]) { + case congr_arg_kind::Eq: { + lean_assert(m_ctx.is_def_eq(binding_domain(fn_type1), binding_domain(fn_type2))); + rhs = m_ctx.mk_tmp_local(binding_name(fn_type2), binding_domain(fn_type2)); + expr eq_type = m_builder.mk_eq(lhs, rhs); + rhss.push_back(rhs); + expr eq = m_ctx.mk_tmp_local(e_name.append_after(eqs.size()+1), eq_type); + eqs.push_back(some_expr(eq)); + hyps.push_back(rhs); + hyps.push_back(eq); + simp_lemma_args.push_back(rhs); + simp_lemma_args.push_back(eq); + break; + } + case congr_arg_kind::Fixed: + rhs = lhs; + rhss.push_back(rhs); + eqs.push_back(none_expr()); + break; + case congr_arg_kind::Cast: { + rhs = m_ctx.mk_tmp_local(binding_name(fn_type2), binding_domain(fn_type2)); + rhss.push_back(rhs); + eqs.push_back(none_expr()); + hyps.push_back(rhs); + break; + }} + fn_type1 = whnf(instantiate(binding_body(fn_type1), lhs)); + fn_type2 = whnf(instantiate(binding_body(fn_type2), rhs)); + } + expr pr1 = mk_app(simp_lemma->get_proof(), simp_lemma_args); + expr type1 = simp_lemma->get_type(); + while (is_pi(type1)) + type1 = binding_body(type1); + type1 = instantiate_rev(type1, simp_lemma_args.size(), simp_lemma_args.data()); + expr lhs1, rhs1; + lean_verify(is_eq(type1, lhs1, rhs1)); + // build proof2 + expr rhs2 = mk_app(fn, rhss); + expr eq = m_builder.mk_eq(lhs1, rhs2); + expr congr_type = Pi(hyps, eq); + // build proof that rhs1 = rhs2 + unsigned i; + for (i = 0; i < kinds.size(); i++) { + if (kinds[i] == congr_arg_kind::Cast && !pinfos[i].is_prop()) + break; + } + if (i == kinds.size()) { + // rhs1 and rhs2 are definitionally equal + expr congr_proof = Fun(hyps, pr1); + return optional(congr_type, congr_proof, to_list(kinds)); + } + buffer rhss1; + get_app_args(rhs1, rhss1); + lean_assert(rhss.size() == rhss1.size()); + expr a = mk_app(fn, i, rhss1.data()); + expr pr2 = m_builder.mk_eq_refl(a); + for (; i < kinds.size(); i++) { + if (kinds[i] == congr_arg_kind::Cast && !pinfos[i].is_prop()) { + lean_assert(pinfos[i].is_subsingleton()); + expr r1 = rhss1[i]; + expr r2 = rhss[i]; + expr r1_eq_r2 = m_builder.mk_app(get_subsingleton_elim_name(), r1, r2); + pr2 = m_builder.mk_congr(pr2, r1_eq_r2); + } else { + pr2 = m_builder.mk_congr_fun(pr2, rhss[i]); + } + } + expr congr_proof = Fun(hyps, m_builder.mk_eq_trans(pr1, pr2)); + return optional(congr_type, congr_proof, to_list(kinds)); + } catch (app_builder_exception &) { + return optional(); + } + } + public: imp(app_builder & b, fun_info_manager & fm, bool ignore_inst_implicit): m_builder(b), m_fmanager(fm), m_ctx(fm.ctx()), m_ignore_inst_implicit(ignore_inst_implicit) {} @@ -202,8 +298,8 @@ public: } optional mk_congr_simp(expr const & fn, unsigned nargs) { - auto r = m_cache.find(key(fn, nargs)); - if (r != m_cache.end()) + auto r = m_simp_cache.find(key(fn, nargs)); + if (r != m_simp_cache.end()) return optional(r->second); fun_info finfo = m_fmanager.get(fn, nargs); list const & result_deps = finfo.get_dependencies(); @@ -238,7 +334,7 @@ public: } auto new_r = mk_congr_simp(fn, pinfos, kinds); if (new_r) { - m_cache.insert(mk_pair(key(fn, nargs), *new_r)); + m_simp_cache.insert(mk_pair(key(fn, nargs), *new_r)); return new_r; } else if (has_cast(kinds)) { // remove casts and try again @@ -248,7 +344,7 @@ public: } auto new_r = mk_congr_simp(fn, pinfos, kinds); if (new_r) { - m_cache.insert(mk_pair(key(fn, nargs), *new_r)); + m_simp_cache.insert(mk_pair(key(fn, nargs), *new_r)); return new_r; } else { return new_r; @@ -257,6 +353,47 @@ public: return new_r; } } + + optional mk_congr(expr const & fn) { + fun_info finfo = m_fmanager.get(fn); + return mk_congr(fn, finfo.get_arity()); + } + + optional mk_congr(expr const & fn, unsigned nargs) { + auto r = m_cache.find(key(fn, nargs)); + if (r != m_cache.end()) + return optional(r->second); + fun_info finfo = m_fmanager.get(fn, nargs); + optional simp_lemma = mk_congr_simp(fn, nargs); + if (!simp_lemma) + return optional(); + buffer kinds; + buffer pinfos; + to_buffer(simp_lemma->get_arg_kinds(), kinds); + to_buffer(finfo.get_params_info(), pinfos); + // For congr lemmas we have the following restriction: + // if a Cast arg is subsingleton, it is not a proposition, + // and it is a dependent argument, then we mark it as fixed. + // This restriction doesn't affect the standard library, + // but it simplifies the implementation. + lean_assert(kinds.size() == pinfos.size()); + bool has_cast = false; + for (unsigned i = 0; i < kinds.size(); i++) { + if (!pinfos[i].is_prop() && pinfos[i].is_subsingleton() && pinfos[i].is_dep()) { + kinds[i] = congr_arg_kind::Fixed; + } + if (kinds[i] == congr_arg_kind::Cast) + has_cast = true; + } + if (!has_cast) { + m_cache.insert(mk_pair(key(fn, nargs), *simp_lemma)); + return simp_lemma; // simp_lemma will be identical to regular congr lemma + } + auto new_r = mk_congr(fn, simp_lemma, pinfos, kinds); + if (new_r) + m_cache.insert(mk_pair(key(fn, nargs), *new_r)); + return new_r; + } }; congr_lemma_manager::congr_lemma_manager(app_builder & b, fun_info_manager & fm, bool ignore_inst_implicit): @@ -272,4 +409,10 @@ auto congr_lemma_manager::mk_congr_simp(expr const & fn) -> optional { auto congr_lemma_manager::mk_congr_simp(expr const & fn, unsigned nargs) -> optional { return m_ptr->mk_congr_simp(fn, nargs); } +auto congr_lemma_manager::mk_congr(expr const & fn) -> optional { + return m_ptr->mk_congr(fn); +} +auto congr_lemma_manager::mk_congr(expr const & fn, unsigned nargs) -> optional { + return m_ptr->mk_congr(fn, nargs); +} } diff --git a/src/library/congr_lemma_manager.h b/src/library/congr_lemma_manager.h index f9849a0875..bdd76c937a 100644 --- a/src/library/congr_lemma_manager.h +++ b/src/library/congr_lemma_manager.h @@ -41,5 +41,8 @@ public: optional mk_congr_simp(expr const & fn); optional mk_congr_simp(expr const & fn, unsigned nargs); + + optional mk_congr(expr const & fn); + optional mk_congr(expr const & fn, unsigned nargs); }; }