diff --git a/src/library/blast/congruence_closure.cpp b/src/library/blast/congruence_closure.cpp index 59ae8e962b..4f139a6233 100644 --- a/src/library/blast/congruence_closure.cpp +++ b/src/library/blast/congruence_closure.cpp @@ -613,6 +613,70 @@ int congruence_closure::compare_root(name const & R, expr e1, expr e2) const { return expr_quick_cmp()(e1, e2); } +int congruence_closure::eq_congr_key_cmp::operator()(eq_congr_key const & k1, eq_congr_key const & k2) const { + lean_assert(g_heq_based); + if (k1.m_hash != k2.m_hash) + return unsigned_cmp()(k1.m_hash, k2.m_hash); + lean_assert(is_app(k1.m_expr) && is_app(k2.m_expr)); + expr const & f1 = app_fn(k1.m_expr); + expr const & a1 = app_arg(k1.m_expr); + expr const & f2 = app_fn(k2.m_expr); + expr const & a2 = app_arg(k2.m_expr); + int r = g_cc->compare_root(get_eq_name(), a1, a2); + if (r != 0) return r; + r = g_cc->compare_root(get_eq_name(), f1, f2); + if (r != 0) return r; + if (is_app(f1) && is_app(f2)) return 0; /* composite */ + if (f1 == f2) return 0; /* identical functions */ + expr f1_type = infer_type(f1); + expr f2_type = infer_type(f2); + if (is_def_eq(f1_type, f2_type)) return 0; /* same function type */ + /* + f1 and f2 are not applications and have different types. + We can't generate a congruence proof in this case because the following lemma + + hcongr : f1 == f2 -> a1 == a2 -> f1 a1 == f2 a2 + + is not provable. Remark: it is also not provable in MLTT, Coq and Agda (even if we assume UIP). + */ + return expr_quick_cmp()(f1, f2); +} + +/* \brief Create a equality congruence table key. + \remark This table and key are only used when heterogeneous equality support is enabled. */ +auto congruence_closure::mk_eq_congr_key(expr const & e) const -> eq_congr_key { + lean_assert(is_app(e)); + eq_congr_key k; + k.m_expr = e; + expr const & f = app_fn(e); + expr const & a = app_arg(e); + unsigned h = hash(get_root(get_eq_name(), f).hash(), get_root(get_eq_name(), a).hash()); + k.m_hash = h; + return k; +} + +int congruence_closure::cmp_eq_iff_keys(congr_key const & k1, congr_key const & k2) const { + lean_assert(k1.m_eq == k2.m_eq); + lean_assert(k1.m_iff == k2.m_iff); + lean_assert(k1.m_eq != k1.m_iff); + name const & R = k1.m_eq ? get_eq_name() : get_iff_name(); + expr const & lhs1 = app_arg(app_fn(k1.m_expr)); + expr const & rhs1 = app_arg(k1.m_expr); + expr const & lhs2 = app_arg(app_fn(k2.m_expr)); + expr const & rhs2 = app_arg(k2.m_expr); + return compare_symm(R, lhs1, rhs1, lhs2, rhs2); +} + +int congruence_closure::cmp_symm_rel_keys(congr_key const & k1, congr_key const & k2) const { + name R1, R2; + expr lhs1, rhs1, lhs2, rhs2; + lean_verify(is_equivalence_relation_app(k1.m_expr, R1, lhs1, rhs1)); + lean_verify(is_equivalence_relation_app(k2.m_expr, R2, lhs2, rhs2)); + if (R1 != R2) + return quick_cmp(R1, R2); + return compare_symm(R1, lhs1, rhs1, lhs2, rhs2); +} + int congruence_closure::congr_key_cmp::operator()(congr_key const & k1, congr_key const & k2) const { if (k1.m_hash != k2.m_hash) return unsigned_cmp()(k1.m_hash, k2.m_hash); @@ -624,67 +688,55 @@ int congruence_closure::congr_key_cmp::operator()(congr_key const & k1, congr_ke return k1.m_iff ? -1 : 1; if (k2.m_symm_rel != k2.m_symm_rel) return k1.m_symm_rel ? -1 : 1; - if (k1.m_eq || k1.m_iff) { - name const & R = k1.m_eq ? get_eq_name() : get_iff_name(); - expr const & lhs1 = app_arg(app_fn(k1.m_expr)); - expr const & rhs1 = app_arg(k1.m_expr); - expr const & lhs2 = app_arg(app_fn(k2.m_expr)); - expr const & rhs2 = app_arg(k2.m_expr); - return g_cc->compare_symm(R, lhs1, rhs1, lhs2, rhs2); - } else if (k1.m_symm_rel) { - name R1, R2; - expr lhs1, rhs1, lhs2, rhs2; - lean_verify(is_equivalence_relation_app(k1.m_expr, R1, lhs1, rhs1)); - lean_verify(is_equivalence_relation_app(k2.m_expr, R2, lhs2, rhs2)); - if (R1 != R2) - return quick_cmp(R1, R2); - return g_cc->compare_symm(R1, lhs1, rhs1, lhs2, rhs2); - } else { - lean_assert(!k1.m_eq && !k2.m_eq && !k1.m_iff && !k2.m_iff && - !k1.m_symm_rel && !k2.m_symm_rel); - lean_assert(k1.m_R == k2.m_R); - buffer args1, args2; - expr const & fn1 = get_app_args(k1.m_expr, args1); - expr const & fn2 = get_app_args(k2.m_expr, args2); - if (args1.size() != args2.size()) - return unsigned_cmp()(args1.size(), args2.size()); - auto lemma = mk_ext_congr_lemma(k1.m_R, k1.m_expr); - lean_assert(lemma); - if (!lemma->m_fixed_fun) { - int r = g_cc->compare_root(get_eq_name(), fn1, fn2); + if (k1.m_eq || k1.m_iff) + return g_cc->cmp_eq_iff_keys(k1, k2); + if (k1.m_symm_rel) + return g_cc->cmp_symm_rel_keys(k1, k2); + + lean_assert(!k1.m_eq && !k2.m_eq && !k1.m_iff && !k2.m_iff && + !k1.m_symm_rel && !k2.m_symm_rel); + lean_assert(k1.m_R == k2.m_R); + buffer args1, args2; + expr const & fn1 = get_app_args(k1.m_expr, args1); + expr const & fn2 = get_app_args(k2.m_expr, args2); + if (args1.size() != args2.size()) + return unsigned_cmp()(args1.size(), args2.size()); + auto lemma = mk_ext_congr_lemma(k1.m_R, k1.m_expr); + lean_assert(lemma); + if (!lemma->m_fixed_fun) { + int r = g_cc->compare_root(get_eq_name(), fn1, fn2); + if (r != 0) return r; + for (unsigned i = 0; i < args1.size(); i++) { + r = g_cc->compare_root(get_eq_name(), args1[i], args2[i]); if (r != 0) return r; - for (unsigned i = 0; i < args1.size(); i++) { - r = g_cc->compare_root(get_eq_name(), args1[i], args2[i]); - if (r != 0) return r; - } - return 0; - } else { - list> const * it1 = &lemma->m_rel_names; - list const * it2 = &lemma->m_congr_lemma.get_arg_kinds(); - int r; - for (unsigned i = 0; i < args1.size(); i++) { - lean_assert(*it1); lean_assert(*it2); - switch (head(*it2)) { - case congr_arg_kind::HEq: - case congr_arg_kind::Eq: - lean_assert(head(*it1)); - r = g_cc->compare_root(*head(*it1), args1[i], args2[i]); - if (r != 0) return r; - break; - case congr_arg_kind::Fixed: - case congr_arg_kind::FixedNoParam: - r = expr_quick_cmp()(args1[i], args2[i]); - if (r != 0) return r; - break; - case congr_arg_kind::Cast: - // do nothing... ignore argument - break; - } - it1 = &(tail(*it1)); - it2 = &(tail(*it2)); - } - return 0; } + return 0; + } else { + list> const * it1 = &lemma->m_rel_names; + list const * it2 = &lemma->m_congr_lemma.get_arg_kinds(); + int r; + for (unsigned i = 0; i < args1.size(); i++) { + lean_assert(*it1); lean_assert(*it2); + switch (head(*it2)) { + case congr_arg_kind::HEq: + case congr_arg_kind::Eq: + lean_assert(head(*it1)); + r = g_cc->compare_root(*head(*it1), args1[i], args2[i]); + if (r != 0) return r; + break; + case congr_arg_kind::Fixed: + case congr_arg_kind::FixedNoParam: + r = expr_quick_cmp()(args1[i], args2[i]); + if (r != 0) return r; + break; + case congr_arg_kind::Cast: + // do nothing... ignore argument + break; + } + it1 = &(tail(*it1)); + it2 = &(tail(*it2)); + } + return 0; } } @@ -773,17 +825,41 @@ void congruence_closure::check_iff_true(congr_key const & k) { push_todo(get_iff_name(), e, mk_true(), *g_iff_true_mark, heq_proof); } +void congruence_closure::add_eq_congruence_table(expr const & e) { + lean_assert(is_app(e)); + lean_assert(g_heq_based); + eq_congr_key k = mk_eq_congr_key(e); + if (auto old_k = m_eq_congruences.find(k)) { + /* + Found new equivalence: e ~ old_k->m_expr + 1. Update m_cg_root field for e + */ + eqc_key k(get_eq_name(), e); + entry new_entry = *m_entries.find(k); + new_entry.m_cg_root = old_k->m_expr; + m_entries.insert(k, new_entry); + /* 2. Put new equivalence in the TODO queue */ + /* TODO(Leo): check if the following line is a bottleneck */ + bool heq_proof = !is_def_eq(infer_type(e), infer_type(old_k->m_expr)); + push_todo(get_eq_name(), e, old_k->m_expr, *g_congr_mark, heq_proof); + } else { + m_eq_congruences.insert(k); + } +} + void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, expr const & e) { lean_assert(is_app(e)); congr_key k = mk_congr_key(lemma, e); if (auto old_k = m_congruences.find(k)) { - // Found new equivalence: e ~ old_k->m_expr - // 1. Update m_cg_root field for e + /* + Found new equivalence: e ~ old_k->m_expr + 1. Update m_cg_root field for e + */ eqc_key k(lemma.m_R, e); entry new_entry = *m_entries.find(k); new_entry.m_cg_root = old_k->m_expr; m_entries.insert(k, new_entry); - // 2. Put new equivalence in the TODO queue + /* 2. Put new equivalence in the TODO queue */ bool heq_proof = false; if (lemma.m_heq_result) { lean_assert(g_heq_based); @@ -893,32 +969,44 @@ void congruence_closure::internalize_core(name R, expr const & e, bool toplevel, case expr_kind::App: { bool is_lapp = is_logical_app(e); mk_entry_core(R, e, to_propagate && !is_lapp); - buffer args; - expr const & fn = get_app_args(e, args); - if (toplevel) { - if (is_lapp) { - to_propagate = true; // we must propagate the children of a top-level logical app (or, and, iff, ite) - } else { - toplevel = false; // children of a non-logical application will not be marked as toplevel - } + if (R == get_eq_name() && g_heq_based) { + bool toplevel = false; + bool to_propagate = false; + internalize_core(R, app_fn(e), toplevel, to_propagate); + internalize_core(R, app_arg(e), toplevel, to_propagate); + add_occurrence(R, e, R, app_fn(e)); + add_occurrence(R, e, R, app_arg(e)); + add_eq_congruence_table(e); } else { - to_propagate = false; - } - if (auto lemma = mk_ext_congr_lemma(R, e)) { - list> const * it = &(lemma->m_rel_names); - for (expr const & arg : args) { - lean_assert(*it); - if (auto R1 = head(*it)) { - internalize_core(*R1, arg, toplevel, to_propagate); - add_occurrence(R, e, *R1, arg); + /* Handle user-defined congruence lemmas, congruence lemmas for other relations, + and automatically generated lemmas for weakly-dependent-functions and relations. */ + buffer args; + expr const & fn = get_app_args(e, args); + if (toplevel) { + if (is_lapp) { + to_propagate = true; // we must propagate the children of a top-level logical app (or, and, iff, ite) + } else { + toplevel = false; // children of a non-logical application will not be marked as toplevel } - it = &tail(*it); + } else { + to_propagate = false; } - if (!lemma->m_fixed_fun) { - internalize_core(get_eq_name(), fn, false, false); - add_occurrence(get_eq_name(), e, get_eq_name(), fn); + if (auto lemma = mk_ext_congr_lemma(R, e)) { + list> const * it = &(lemma->m_rel_names); + for (expr const & arg : args) { + lean_assert(*it); + if (auto R1 = head(*it)) { + internalize_core(*R1, arg, toplevel, to_propagate); + add_occurrence(R, e, *R1, arg); + } + it = &tail(*it); + } + if (!lemma->m_fixed_fun) { + internalize_core(get_eq_name(), fn, false, false); + add_occurrence(get_eq_name(), e, get_eq_name(), fn); + } + add_congruence_table(*lemma, e); } - add_congruence_table(*lemma, e); } apply_simple_eqvs(e); break; @@ -972,10 +1060,15 @@ void congruence_closure::remove_parents(name const & R, expr const & e) { auto ps = m_parents.find(child_key(R, e)); if (!ps) return; ps->for_each([&](parent_occ const & p) { - auto lemma = mk_ext_congr_lemma(p.m_R, p.m_expr); - lean_assert(lemma); - congr_key k = mk_congr_key(*lemma, p.m_expr); - m_congruences.erase(k); + if (g_heq_based && R == get_eq_name() && p.m_R == get_eq_name()) { + eq_congr_key k = mk_eq_congr_key(p.m_expr); + m_eq_congruences.erase(k); + } else { + auto lemma = mk_ext_congr_lemma(p.m_R, p.m_expr); + lean_assert(lemma); + congr_key k = mk_congr_key(*lemma, p.m_expr); + m_congruences.erase(k); + } }); } @@ -983,9 +1076,13 @@ void congruence_closure::reinsert_parents(name const & R, expr const & e) { auto ps = m_parents.find(child_key(R, e)); if (!ps) return; ps->for_each([&](parent_occ const & p) { - auto lemma = mk_ext_congr_lemma(p.m_R, p.m_expr); - lean_assert(lemma); - add_congruence_table(*lemma, p.m_expr); + if (g_heq_based && R == get_eq_name() && p.m_R == get_eq_name()) { + add_eq_congruence_table(p.m_expr); + } else { + auto lemma = mk_ext_congr_lemma(p.m_R, p.m_expr); + lean_assert(lemma); + add_congruence_table(*lemma, p.m_expr); + } }); } @@ -1037,18 +1134,20 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con lean_assert(r1 && r2); bool flipped = false; - // We want r2 to be the root of the combined class. + /* We want r2 to be the root of the combined class. */ - // We swap (e1,n1,r1) with (e2,n2,r2) when - // 1- r1->m_interpreted && !r2->m_interpreted. - // Reason: to decide when to propagate we check whether the root of the equivalence class - // is true/false. So, this condition is to make sure if true/false is an equivalence class, - // then one of them is the root. If both are, it doesn't matter, since the state is inconsistent - // anyway. - // 2- r1->m_constructor && !r2->m_interpreted && !r2->m_constructor - // Reason: we want constructors to be the representative of their equivalence classes. - // 2- r1->m_size > r2->m_size && !r2->m_interpreted && !r2->m_constructor - // Reason: performance. + /* + We swap (e1,n1,r1) with (e2,n2,r2) when + 1- r1->m_interpreted && !r2->m_interpreted. + Reason: to decide when to propagate we check whether the root of the equivalence class + is true/false. So, this condition is to make sure if true/false is an equivalence class, + then one of them is the root. If both are, it doesn't matter, since the state is inconsistent + anyway. + 2- r1->m_constructor && !r2->m_interpreted && !r2->m_constructor + Reason: we want constructors to be the representative of their equivalence classes. + 3- r1->m_size > r2->m_size && !r2->m_interpreted && !r2->m_constructor + Reason: performance. + */ if ((r1->m_interpreted && !r2->m_interpreted) || (r1->m_constructor && !r2->m_interpreted && !r2->m_constructor) || (r1->m_size > r2->m_size && !r2->m_interpreted && !r2->m_constructor)) { @@ -1068,21 +1167,23 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con expr e2_root = n2->m_root; entry new_n1 = *n1; - // Following target/proof we have - // e1 -> ... -> r1 - // e2 -> ... -> r2 - // We want - // r1 -> ... -> e1 -> e2 -> ... -> r2 + /* + Following target/proof we have + e1 -> ... -> r1 + e2 -> ... -> r2 + We want + r1 -> ... -> e1 -> e2 -> ... -> r2 + */ invert_trans(R, e1); new_n1.m_target = e2; new_n1.m_proof = H; new_n1.m_flipped = flipped; m_entries.insert(eqc_key(R, e1), new_n1); - // The hash code for the parents is going to change + /* The hash code for the parents is going to change */ remove_parents(R, e1_root); - // force all m_root fields in e1 equivalence class to point to e2_root + /* force all m_root fields in e1 equivalence class to point to e2_root */ bool propagate = R == get_iff_name() && is_true_or_false(e2_root); buffer to_propagate; expr it = e1; @@ -1296,7 +1397,74 @@ bool congruence_closure::is_eqv(name const & R, expr const & e1, expr const & e2 return n1->m_root == n2->m_root; } +expr congruence_closure::mk_eq_congr_proof(expr const & lhs, expr const & rhs, bool heq_proofs) const { + lean_assert(g_heq_based); + app_builder & b = get_app_builder(); + buffer lhs_args, rhs_args; + expr const * lhs_it = &lhs; + expr const * rhs_it = &rhs; + while (is_app(*lhs_it) && is_app(*rhs_it) && *lhs_it != *rhs_it) { + lean_assert(is_eqv(get_eq_name(), *lhs_it, *rhs_it)); + lhs_args.push_back(app_arg(*lhs_it)); + rhs_args.push_back(app_arg(*rhs_it)); + lhs_it = &app_fn(*lhs_it); + rhs_it = &app_fn(*rhs_it); + } + if (lhs_args.empty()) { + if (heq_proofs) + return b.mk_heq_refl(lhs); + else + return b.mk_eq_refl(lhs); + } + std::reverse(lhs_args.begin(), lhs_args.end()); + std::reverse(rhs_args.begin(), rhs_args.end()); + lean_assert(lhs_args.size() == rhs_args.size()); + expr const & lhs_fn = *lhs_it; + expr const & rhs_fn = *rhs_it; + lean_assert(is_eqv(get_eq_name(), lhs_fn, rhs_fn) || is_def_eq(lhs_fn, rhs_fn)); + lean_assert(is_def_eq(infer_type(lhs_fn), infer_type(rhs_fn))); + /* Create proof for + (lhs_fn lhs_args[0] ... lhs_args[n-1]) = (lhs_fn rhs_args[0] ... rhs_args[n-1]) + where + n == lhs_args.size() + */ + auto spec_lemma = mk_ext_hcongr_lemma(lhs_fn, lhs_args.size()); + lean_assert(spec_lemma); + list const * kinds_it = &spec_lemma->m_congr_lemma.get_arg_kinds(); + buffer lemma_args; + for (unsigned i = 0; i < lhs_args.size(); i++) { + lean_assert(kinds_it); + lemma_args.push_back(lhs_args[i]); + lemma_args.push_back(rhs_args[i]); + if (head(*kinds_it) == congr_arg_kind::HEq) { + lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i])); + } else { + lean_assert(head(*kinds_it) == congr_arg_kind::Eq); + lemma_args.push_back(*get_eqv_proof(get_eq_name(), lhs_args[i], rhs_args[i])); + } + kinds_it = &(tail(*kinds_it)); + } + expr r = mk_app(spec_lemma->m_congr_lemma.get_proof(), lemma_args); + if (spec_lemma->m_heq_result && !heq_proofs) + r = b.mk_eq_of_heq(r); + else if (!spec_lemma->m_heq_result && heq_proofs) + r = b.mk_heq_of_eq(r); + if (is_def_eq(lhs_fn, rhs_fn)) + return r; + /* Convert r into a proof of lhs = rhs using eq.rec and + the proof that lhs_fn = rhs_fn */ + expr lhs_fn_eq_rhs_fn = *get_eqv_proof(get_eq_name(), lhs_fn, rhs_fn); + expr x = mk_fresh_local(infer_type(lhs_fn)); + expr motive_rhs = mk_app(x, rhs_args); + expr motive = heq_proofs ? b.mk_heq(lhs, motive_rhs) : b.mk_eq(lhs, motive_rhs); + return b.mk_eq_rec(Fun(x, motive), r, lhs_fn_eq_rhs_fn); +} + expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, expr const & rhs, bool heq_proofs) const { + if (g_heq_based && (R == get_eq_name() || R == get_heq_name())) { + /* Use general eq congruence lemmas when heterogeneous equality is enabled. */ + return mk_eq_congr_proof(lhs, rhs, heq_proofs); + } app_builder & b = get_app_builder(); buffer lhs_args, rhs_args; expr const & lhs_fn = get_app_args(lhs, lhs_args); @@ -1305,92 +1473,49 @@ expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, e auto lemma = mk_ext_congr_lemma(R, lhs); lean_assert(lemma); if (lemma->m_fixed_fun) { - if (g_heq_based && lemma->m_hcongr_lemma && (R == get_eq_name() || R == get_heq_name())) { - /* Try to simplify congruence proof by consuming common prefix of lhs and rhs */ - /* This branch is an optimization, and it is not necessary */ - unsigned i = 0; - for (; i < lhs_args.size(); i++) { - if (!is_def_eq(lhs_args[i], rhs_args[i])) - break; - } - unsigned prefix_sz = i; - unsigned rest_sz = lhs_args.size() - prefix_sz; - if (rest_sz == 0) { - if (heq_proofs) - return b.mk_heq_refl(lhs); - else - return b.mk_eq_refl(lhs); - } - expr g = lhs; - for (unsigned i = 0; i < rest_sz; i++) g = app_fn(g); - auto spec_lemma = mk_ext_hcongr_lemma(g, rest_sz); - lean_assert(spec_lemma); - list const * it = &spec_lemma->m_congr_lemma.get_arg_kinds(); - buffer lemma_args; - for (unsigned i = prefix_sz; i < lhs_args.size(); i++) { - lean_assert(it); + /* Main case: convers user-defined congruence lemmas, and + all automatically generated congruence lemmas */ + list> const * it1 = &lemma->m_rel_names; + list const * it2 = &lemma->m_congr_lemma.get_arg_kinds(); + buffer lemma_args; + for (unsigned i = 0; i < lhs_args.size(); i++) { + lean_assert(*it1 && *it2); + switch (head(*it2)) { + case congr_arg_kind::HEq: lemma_args.push_back(lhs_args[i]); lemma_args.push_back(rhs_args[i]); - if (head(*it) == congr_arg_kind::HEq) { - lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i])); - } else { - lean_assert(head(*it) == congr_arg_kind::Eq); - lemma_args.push_back(*get_eqv_proof(get_eq_name(), lhs_args[i], rhs_args[i])); - } - it = &(tail(*it)); + lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i])); + break; + case congr_arg_kind::Eq: + lean_assert(head(*it1)); + lemma_args.push_back(lhs_args[i]); + lemma_args.push_back(rhs_args[i]); + lemma_args.push_back(*get_eqv_proof(*head(*it1), lhs_args[i], rhs_args[i])); + break; + case congr_arg_kind::Fixed: + lemma_args.push_back(lhs_args[i]); + break; + case congr_arg_kind::FixedNoParam: + break; + case congr_arg_kind::Cast: + lemma_args.push_back(lhs_args[i]); + lemma_args.push_back(rhs_args[i]); + break; } - expr r = mk_app(spec_lemma->m_congr_lemma.get_proof(), lemma_args); - lean_assert(g_heq_based); - if (spec_lemma->m_heq_result && !heq_proofs) - r = b.mk_eq_of_heq(r); - else if (!spec_lemma->m_heq_result && heq_proofs) - r = b.mk_heq_of_eq(r); - return r; - } else { - /* Main case: convers user-defined congruence lemmas, and - all automatically generated congruence lemmas */ - list> const * it1 = &lemma->m_rel_names; - list const * it2 = &lemma->m_congr_lemma.get_arg_kinds(); - buffer lemma_args; - for (unsigned i = 0; i < lhs_args.size(); i++) { - lean_assert(*it1 && *it2); - switch (head(*it2)) { - case congr_arg_kind::HEq: - lemma_args.push_back(lhs_args[i]); - lemma_args.push_back(rhs_args[i]); - lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i])); - break; - case congr_arg_kind::Eq: - lean_assert(head(*it1)); - lemma_args.push_back(lhs_args[i]); - lemma_args.push_back(rhs_args[i]); - lemma_args.push_back(*get_eqv_proof(*head(*it1), lhs_args[i], rhs_args[i])); - break; - case congr_arg_kind::Fixed: - lemma_args.push_back(lhs_args[i]); - break; - case congr_arg_kind::FixedNoParam: - break; - case congr_arg_kind::Cast: - lemma_args.push_back(lhs_args[i]); - lemma_args.push_back(rhs_args[i]); - break; - } - it1 = &(tail(*it1)); - it2 = &(tail(*it2)); - } - expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args); - if (lemma->m_lift_needed) { - r = b.lift_from_eq(R, r); - } - if (g_heq_based) { - if (lemma->m_heq_result && !heq_proofs) - r = b.mk_eq_of_heq(r); - else if (!lemma->m_heq_result && heq_proofs) - r = b.mk_heq_of_eq(r); - } - return r; + it1 = &(tail(*it1)); + it2 = &(tail(*it2)); } + expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args); + if (lemma->m_lift_needed) { + r = b.lift_from_eq(R, r); + } + if (g_heq_based) { + if (lemma->m_heq_result && !heq_proofs) + r = b.mk_eq_of_heq(r); + else if (!lemma->m_heq_result && heq_proofs) + r = b.mk_heq_of_eq(r); + } + return r; } else { /* This branch builds congruence proofs that handle equality between functions. The proof is created using congr_arg/congr_fun/congr lemmas. diff --git a/src/library/blast/congruence_closure.h b/src/library/blast/congruence_closure.h index 06c513d936..27b19ce304 100644 --- a/src/library/blast/congruence_closure.h +++ b/src/library/blast/congruence_closure.h @@ -76,7 +76,22 @@ class congruence_closure { } }; - /* Key for the congruence set */ + /* Key for the equality congruence table. + + \remark We only use the equality congruence table when the support for heterogeneous equality + is turned on (see blast.cc.heq option). Otherwise, we store equality congruences in the + generic congruence table and rely on automatically generated congruence lemmas for + weakly dependent functions. */ + struct eq_congr_key { + expr m_expr; + unsigned m_hash; + }; + + struct eq_congr_key_cmp { + int operator()(eq_congr_key const & k1, eq_congr_key const & k2) const; + }; + + /* Key for the congruence table. */ struct congr_key { name m_R; expr m_expr; @@ -93,6 +108,9 @@ class congruence_closure { congr_key() { m_eq = 0; m_iff = 0; m_symm_rel = 0; } }; + int cmp_eq_iff_keys(congr_key const & k1, congr_key const & k2) const; + int cmp_symm_rel_keys(congr_key const & k1, congr_key const & k2) const; + struct congr_key_cmp { int operator()(congr_key const & k1, congr_key const & k2) const; }; @@ -105,10 +123,12 @@ class congruence_closure { typedef eqc_key_cmp parent_occ_cmp; typedef rb_tree parent_occ_set; typedef rb_map parents; + typedef rb_tree eq_congruences; typedef rb_tree congruences; typedef rb_map subsingleton_reprs; entries m_entries; parents m_parents; + eq_congruences m_eq_congruences; congruences m_congruences; /** The following mapping store a representative for each subsingleton type */ subsingleton_reprs m_subsingleton_reprs; @@ -129,6 +149,7 @@ class congruence_closure { int compare_symm(name const & R, expr lhs1, expr rhs1, expr lhs2, expr rhs2) const; int compare_root(name const & R, expr e1, expr e2) const; + eq_congr_key mk_eq_congr_key(expr const & e) const; unsigned symm_hash(name const & R, expr const & lhs, expr const & rhs) const; congr_key mk_congr_key(ext_congr_lemma const & lemma, expr const & e) const; void check_iff_true(congr_key const & k); @@ -140,6 +161,7 @@ class congruence_closure { void mk_entry_core(name const & R, expr const & e, bool to_propagate); void mk_entry(name const & R, expr const & e, bool to_propagate); void add_occurrence(name const & Rp, expr const & parent, name const & Rc, expr const & child); + void add_eq_congruence_table(expr const & e); void add_congruence_table(ext_congr_lemma const & lemma, expr const & e); void invert_trans(name const & R, expr const & e, bool new_flipped, optional new_target, optional new_proof); void invert_trans(name const & R, expr const & e); @@ -152,6 +174,7 @@ class congruence_closure { void add_eqv_core(name const & R, expr const & lhs, expr const & rhs, expr const & H, optional const & added_prop, bool heq_proof); void propagate_no_confusion_eq(expr const & e1, expr const & e2); + expr mk_eq_congr_proof(expr const & lhs, expr const & rhs, bool heq_proofs) const; expr mk_congr_proof_core(name const & R, expr const & lhs, expr const & rhs, bool heq_proofs) const; expr mk_congr_proof(name const & R, expr const & lhs, expr const & rhs, bool heq_proofs) const; expr mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H, bool heq_proofs) const; diff --git a/tests/lean/run/blast_cc_heq9.lean b/tests/lean/run/blast_cc_heq9.lean new file mode 100644 index 0000000000..f991435953 --- /dev/null +++ b/tests/lean/run/blast_cc_heq9.lean @@ -0,0 +1,8 @@ +example (f g : Π {A : Type₁}, A → A → A) (h : nat → nat) (a b : nat) : + h = f a → h b = f a b := +by blast + + +example (f g : Π {A : Type₁} (a b : A), A) (h : nat → nat) (a b : nat) : + h = f a → h b = f a b := +by blast