diff --git a/src/library/blast/congruence_closure.cpp b/src/library/blast/congruence_closure.cpp index cad434823a..650c21bbd7 100644 --- a/src/library/blast/congruence_closure.cpp +++ b/src/library/blast/congruence_closure.cpp @@ -79,21 +79,24 @@ ext_congr_lemma::ext_congr_lemma(congr_lemma const & H): m_rel_names(rel_names_from_arg_kinds(H.get_arg_kinds(), get_eq_name())), m_lift_needed(false), m_fixed_fun(true), - m_heq_result(false) {} -ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed, bool heq_based): + m_heq_result(false), + m_hcongr_lemma(false) {} +ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed): m_R(R), m_congr_lemma(H), m_rel_names(rel_names_from_arg_kinds(H.get_arg_kinds(), get_eq_name())), m_lift_needed(lift_needed), m_fixed_fun(true), - m_heq_result(heq_based) {} + m_heq_result(false), + m_hcongr_lemma(false) {} ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, list> const & rel_names, bool lift_needed): m_R(R), m_congr_lemma(H), m_rel_names(rel_names), m_lift_needed(lift_needed), m_fixed_fun(true), - m_heq_result(false) {} + m_heq_result(false), + m_hcongr_lemma(false) {} /* We use the following cache for user-defined lemmas and automatically generated ones. */ typedef std::unordered_map, congr_lemma_key_hash_fn, congr_lemma_key_eq_fn> congr_cache; @@ -381,12 +384,11 @@ static optional mk_ext_specialized_congr_lemma(name const & R, if (R == get_eq_name()) return optional(res1); bool lift_needed = true; - bool heq_result = false; - return optional(R, *eq_congr, lift_needed, heq_result); + return optional(R, *eq_congr, lift_needed); } /* Automatically generated congruence lemma based on heterogeneous equality. */ -static optional mk_hcongr_lemma(name const & R, expr const & fn, unsigned nargs) { +static optional mk_hcongr_lemma_core(name const & R, expr const & fn, unsigned nargs) { optional eq_congr = mk_hcongr_lemma(fn, nargs); if (!eq_congr) return optional(); @@ -398,6 +400,7 @@ static optional mk_hcongr_lemma(name const & R, expr const & fn res1.m_fixed_fun = false; lean_assert(is_eq(type) || is_heq(type)); if (R == get_eq_name() || R == get_heq_name()) { + res1.m_hcongr_lemma = true; if (is_heq(type)) res1.m_heq_result = true; return optional(res1); @@ -408,9 +411,10 @@ static optional mk_hcongr_lemma(name const & R, expr const & fn /* We cannot lift heterogeneous equality. */ return optional(); } else { - bool heq_result = false; bool lift_needed = true; - return optional(R, *eq_congr, lift_needed, heq_result); + ext_congr_lemma res2(R, *eq_congr, lift_needed); + res2.m_hcongr_lemma = true; + return optional(res2); } } @@ -429,7 +433,7 @@ optional mk_ext_congr_lemma(name const & R, expr const & e) { /* Try automatically generated lemma for equivalence relation over iff/eq */ if (!lemma) lemma = mk_relation_congr_lemma(R, fn, nargs); /* Try automatically generated congruence lemma with support for heterogeneous equality. */ - if (!lemma) lemma = mk_hcongr_lemma(R, fn, nargs); + if (!lemma) lemma = mk_hcongr_lemma_core(R, fn, nargs); if (lemma) { /* succeeded */ @@ -475,6 +479,23 @@ optional mk_ext_congr_lemma(name const & R, expr const & e) { return optional(); } +optional mk_ext_hcongr_lemma(expr const & fn, unsigned nargs) { + congr_lemma_key key1(get_eq_name(), fn, nargs); + auto it1 = g_congr_cache->find(key1); + if (it1 != g_congr_cache->end()) + return it1->second; + + if (auto lemma = mk_hcongr_lemma_core(get_eq_name(), fn, nargs)) { + /* succeeded */ + g_congr_cache->insert(mk_pair(key1, lemma)); + return lemma; + } + + /* cache failure */ + g_congr_cache->insert(mk_pair(key1, optional())); + return optional(); +} + void congruence_closure::update_non_eq_relations(name const & R) { if (R == get_eq_name()) return; @@ -687,7 +708,12 @@ void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, exp new_entry.m_cg_root = old_k->m_expr; m_entries.insert(k, new_entry); // 2. Put new equivalence in the TODO queue - bool heq_proof = lemma.m_heq_result; + bool heq_proof = false; + if (lemma.m_heq_result) { + lean_assert(g_heq_based); + if (!is_def_eq(infer_type(e), infer_type(old_k->m_expr))) + heq_proof = true; + } push_todo(lemma.m_R, e, old_k->m_expr, *g_congr_mark, heq_proof); } else { m_congruences.insert(k); @@ -1159,47 +1185,96 @@ expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, e auto lemma = mk_ext_congr_lemma(R, lhs); lean_assert(lemma); if (lemma->m_fixed_fun) { - list> const * it1 = &lemma->m_rel_names; - list const * it2 = &lemma->m_congr_lemma.get_arg_kinds(); - buffer lemma_args; - for (unsigned i = 0; i < lhs_args.size(); i++) { - lean_assert(*it1 && *it2); - switch (head(*it2)) { - case congr_arg_kind::HEq: - lemma_args.push_back(lhs_args[i]); - lemma_args.push_back(rhs_args[i]); - lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i])); - break; - case congr_arg_kind::Eq: - lean_assert(head(*it1)); - lemma_args.push_back(lhs_args[i]); - lemma_args.push_back(rhs_args[i]); - lemma_args.push_back(*get_eqv_proof(*head(*it1), lhs_args[i], rhs_args[i])); - break; - case congr_arg_kind::Fixed: - lemma_args.push_back(lhs_args[i]); - break; - case congr_arg_kind::FixedNoParam: - break; - case congr_arg_kind::Cast: - lemma_args.push_back(lhs_args[i]); - lemma_args.push_back(rhs_args[i]); - break; + if (g_heq_based && lemma->m_hcongr_lemma && (R == get_eq_name() || R == get_heq_name())) { + /* Try to simplify congruence proof by consuming common prefix of lhs and rhs */ + /* This branch is an optimization, and it is not necessary */ + unsigned i = 0; + for (; i < lhs_args.size(); i++) { + if (!is_def_eq(lhs_args[i], rhs_args[i])) + break; } - it1 = &(tail(*it1)); - it2 = &(tail(*it2)); + unsigned prefix_sz = i; + unsigned rest_sz = lhs_args.size() - prefix_sz; + if (rest_sz == 0) { + if (heq_proofs) + return b.mk_heq_refl(lhs); + else + return b.mk_eq_refl(lhs); + } + expr g = lhs; + for (unsigned i = 0; i < rest_sz; i++) g = app_fn(g); + auto spec_lemma = mk_ext_hcongr_lemma(g, rest_sz); + lean_assert(spec_lemma); + list const * it = &spec_lemma->m_congr_lemma.get_arg_kinds(); + buffer lemma_args; + for (unsigned i = prefix_sz; i < lhs_args.size(); i++) { + lean_assert(it); + lemma_args.push_back(lhs_args[i]); + lemma_args.push_back(rhs_args[i]); + if (head(*it) == congr_arg_kind::HEq) { + lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i])); + } else { + lean_assert(head(*it) == congr_arg_kind::Eq); + lemma_args.push_back(*get_eqv_proof(get_eq_name(), lhs_args[i], rhs_args[i])); + } + it = &(tail(*it)); + } + expr r = mk_app(spec_lemma->m_congr_lemma.get_proof(), lemma_args); + if (spec_lemma->m_heq_result && !heq_proofs) + r = b.mk_eq_of_heq(r); + else if (!spec_lemma->m_heq_result && heq_proofs) + r = b.mk_heq_of_eq(r); + return r; + } else { + /* Main case: convers user-defined congruence lemmas, and + all automatically generated congruence lemmas */ + list> const * it1 = &lemma->m_rel_names; + list const * it2 = &lemma->m_congr_lemma.get_arg_kinds(); + buffer lemma_args; + for (unsigned i = 0; i < lhs_args.size(); i++) { + lean_assert(*it1 && *it2); + switch (head(*it2)) { + case congr_arg_kind::HEq: + lemma_args.push_back(lhs_args[i]); + lemma_args.push_back(rhs_args[i]); + lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i])); + break; + case congr_arg_kind::Eq: + lean_assert(head(*it1)); + lemma_args.push_back(lhs_args[i]); + lemma_args.push_back(rhs_args[i]); + lemma_args.push_back(*get_eqv_proof(*head(*it1), lhs_args[i], rhs_args[i])); + break; + case congr_arg_kind::Fixed: + lemma_args.push_back(lhs_args[i]); + break; + case congr_arg_kind::FixedNoParam: + break; + case congr_arg_kind::Cast: + lemma_args.push_back(lhs_args[i]); + lemma_args.push_back(rhs_args[i]); + break; + } + it1 = &(tail(*it1)); + it2 = &(tail(*it2)); + } + expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args); + if (lemma->m_lift_needed) { + lean_assert(!lemma->m_heq_result); + r = b.lift_from_eq(R, r); + } + if (lemma->m_heq_result && !heq_proofs) + r = b.mk_eq_of_heq(r); + else if (!lemma->m_heq_result && heq_proofs) + r = b.mk_heq_of_eq(r); + return r; } - expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args); - if (lemma->m_lift_needed) { - lean_assert(!lemma->m_heq_result); - r = b.lift_from_eq(R, r); - } - if (lemma->m_heq_result && !heq_proofs) - r = b.mk_eq_of_heq(r); - else if (!lemma->m_heq_result && heq_proofs) - r = b.mk_heq_of_eq(r); - return r; } else { + /* This branch builds congruence proofs that handle equality between functions. + The proof is created using congr_arg/congr_fun/congr lemmas. + It can build proofs for congruence such as: + f = g -> a = b -> f a = g b + but it is limited to simply typed functions. */ optional r; unsigned i = 0; if (!is_def_eq(lhs_fn, rhs_fn)) { diff --git a/src/library/blast/congruence_closure.h b/src/library/blast/congruence_closure.h index 11eee04f83..ab76d27f9d 100644 --- a/src/library/blast/congruence_closure.h +++ b/src/library/blast/congruence_closure.h @@ -263,8 +263,10 @@ struct ext_congr_lemma { unsigned m_fixed_fun:1; /* If m_heq_result is true, then lemma is based on heterogeneous equality and the conclusion is a heterogeneous equality. */ unsigned m_heq_result:1; + /* If m_heq_lemma is true, then lemma was created using mk_hcongr_lemma. */ + unsigned m_hcongr_lemma:1; ext_congr_lemma(congr_lemma const & H); - ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed, bool heq_result); + ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed); ext_congr_lemma(name const & R, congr_lemma const & H, list> const & rel_names, bool lift_needed); name const & get_relation() const { return m_R; } diff --git a/tests/lean/run/blast_cc_heq5.lean b/tests/lean/run/blast_cc_heq5.lean new file mode 100644 index 0000000000..4708833caa --- /dev/null +++ b/tests/lean/run/blast_cc_heq5.lean @@ -0,0 +1,8 @@ +set_option blast.strategy "cc" +set_option blast.cc.heq true + +definition ex1 (a b c a' b' c' : nat) : a = a' → b = b' → c = c' → a + b + c + a = a' + b' + c' + a' := +by blast + +set_option pp.beta true +print ex1