diff --git a/src/library/blast/congruence_closure.cpp b/src/library/blast/congruence_closure.cpp index 4361e38ad0..cad434823a 100644 --- a/src/library/blast/congruence_closure.cpp +++ b/src/library/blast/congruence_closure.cpp @@ -79,21 +79,21 @@ ext_congr_lemma::ext_congr_lemma(congr_lemma const & H): m_rel_names(rel_names_from_arg_kinds(H.get_arg_kinds(), get_eq_name())), m_lift_needed(false), m_fixed_fun(true), - m_heq_based(false) {} + m_heq_result(false) {} ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed, bool heq_based): m_R(R), m_congr_lemma(H), m_rel_names(rel_names_from_arg_kinds(H.get_arg_kinds(), get_eq_name())), m_lift_needed(lift_needed), m_fixed_fun(true), - m_heq_based(heq_based) {} + m_heq_result(heq_based) {} ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, list> const & rel_names, bool lift_needed): m_R(R), m_congr_lemma(H), m_rel_names(rel_names), m_lift_needed(lift_needed), m_fixed_fun(true), - m_heq_based(false) {} + m_heq_result(false) {} /* We use the following cache for user-defined lemmas and automatically generated ones. */ typedef std::unordered_map, congr_lemma_key_hash_fn, congr_lemma_key_eq_fn> congr_cache; @@ -381,8 +381,8 @@ static optional mk_ext_specialized_congr_lemma(name const & R, if (R == get_eq_name()) return optional(res1); bool lift_needed = true; - bool heq_based = false; - return optional(R, *eq_congr, lift_needed, heq_based); + bool heq_result = false; + return optional(R, *eq_congr, lift_needed, heq_result); } /* Automatically generated congruence lemma based on heterogeneous equality. */ @@ -391,23 +391,26 @@ static optional mk_hcongr_lemma(name const & R, expr const & fn if (!eq_congr) return optional(); ext_congr_lemma res1(*eq_congr); - /* If all arguments are Eq kind, then we can use generic congr axiom and consider equality for the function. */ - if (eq_congr->all_eq_kind()) - res1.m_fixed_fun = false; - if (R == get_eq_name() || R == get_heq_name()) - return optional(res1); - /* If R is not equality (=) nor heterogeneous equality (==), - we try to lift, but we can only lift if the congruence lemma produces an equality. */ expr type = eq_congr->get_type(); while (is_pi(type)) type = binding_body(type); + /* If all arguments are Eq kind, then we can use generic congr axiom and consider equality for the function. */ + if (!is_heq(type) && eq_congr->all_eq_kind()) + res1.m_fixed_fun = false; lean_assert(is_eq(type) || is_heq(type)); + if (R == get_eq_name() || R == get_heq_name()) { + if (is_heq(type)) + res1.m_heq_result = true; + return optional(res1); + } + /* If R is not equality (=) nor heterogeneous equality (==), + we try to lift, but we can only lift if the congruence lemma produces an equality. */ if (is_heq(type)) { /* We cannot lift heterogeneous equality. */ return optional(); } else { - bool heq_based = !eq_congr->all_eq_kind() || is_heq(type); + bool heq_result = false; bool lift_needed = true; - return optional(R, *eq_congr, lift_needed, heq_based); + return optional(R, *eq_congr, lift_needed, heq_result); } } @@ -684,7 +687,7 @@ void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, exp new_entry.m_cg_root = old_k->m_expr; m_entries.insert(k, new_entry); // 2. Put new equivalence in the TODO queue - bool heq_proof = false; // TODO(Leo): fix this + bool heq_proof = lemma.m_heq_result; push_todo(lemma.m_R, e, old_k->m_expr, *g_congr_mark, heq_proof); } else { m_congruences.insert(k); @@ -1147,7 +1150,7 @@ bool congruence_closure::is_eqv(name const & R, expr const & e1, expr const & e2 return n1->m_root == n2->m_root; } -expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, expr const & rhs) const { +expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, expr const & rhs, bool heq_proofs) const { app_builder & b = get_app_builder(); buffer lhs_args, rhs_args; expr const & lhs_fn = get_app_args(lhs, lhs_args); @@ -1163,7 +1166,10 @@ expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, e lean_assert(*it1 && *it2); switch (head(*it2)) { case congr_arg_kind::HEq: - lean_unreachable(); + lemma_args.push_back(lhs_args[i]); + lemma_args.push_back(rhs_args[i]); + lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i])); + break; case congr_arg_kind::Eq: lean_assert(head(*it1)); lemma_args.push_back(lhs_args[i]); @@ -1184,8 +1190,14 @@ expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, e it2 = &(tail(*it2)); } expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args); - if (lemma->m_lift_needed) + if (lemma->m_lift_needed) { + lean_assert(!lemma->m_heq_result); r = b.lift_from_eq(R, r); + } + if (lemma->m_heq_result && !heq_proofs) + r = b.mk_eq_of_heq(r); + else if (!lemma->m_heq_result && heq_proofs) + r = b.mk_heq_of_eq(r); return r; } else { optional r; @@ -1222,7 +1234,7 @@ expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, e } } -expr congruence_closure::mk_congr_proof(name const & R, expr const & e1, expr const & e2) const { +expr congruence_closure::mk_congr_proof(name const & R, expr const & e1, expr const & e2, bool heq_proofs) const { name R1; expr lhs1, rhs1; if (is_equivalence_relation_app(e1, R1, lhs1, rhs1)) { name R2; expr lhs2, rhs2; @@ -1243,16 +1255,16 @@ expr congruence_closure::mk_congr_proof(name const & R, expr const & e1, expr co if (R != get_eq_name()) e1_eqv_new_e1 = b.lift_from_eq(R, e1_eqv_new_e1); } - return b.mk_trans(R, e1_eqv_new_e1, mk_congr_proof_core(R, new_e1, e2)); + return b.mk_trans(R, e1_eqv_new_e1, mk_congr_proof_core(R, new_e1, e2, heq_proofs)); } } } - return mk_congr_proof_core(R, e1, e2); + return mk_congr_proof_core(R, e1, e2, heq_proofs); } -expr congruence_closure::mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H) const { +expr congruence_closure::mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H, bool heq_proofs) const { if (H == *g_congr_mark) { - return mk_congr_proof(R, lhs, rhs); + return mk_congr_proof(R, lhs, rhs, heq_proofs); } else if (H == *g_iff_true_mark) { bool flip; name R1; expr a, b; @@ -1275,13 +1287,13 @@ expr congruence_closure::mk_proof(name const & R, expr const & lhs, expr const & } } -static expr flip_proof(name const & R, expr const & H, bool flipped, bool has_heq_proofs) { +static expr flip_proof(name const & R, expr const & H, bool flipped, bool heq_proofs) { if (H == *g_congr_mark || H == *g_iff_true_mark || H == *g_lift_mark) { return H; } else { auto & b = get_app_builder(); expr new_H = H; - if (has_heq_proofs && is_eq(relaxed_whnf(infer_type(new_H)))) { + if (heq_proofs && is_eq(relaxed_whnf(infer_type(new_H)))) { new_H = b.mk_heq_of_eq(new_H); } if (!flipped) { @@ -1361,13 +1373,13 @@ optional congruence_closure::get_eqv_proof(name const & R, expr const & e1 optional pr; expr lhs = e1; for (unsigned i = 0; i < path1.size(); i++) { - pr = mk_trans(R_trans, pr, mk_proof(R, lhs, path1[i], Hs1[i])); + pr = mk_trans(R_trans, pr, mk_proof(R, lhs, path1[i], Hs1[i], heq_proofs)); lhs = path1[i]; } unsigned i = Hs2.size(); while (i > 0) { --i; - pr = mk_trans(R_trans, pr, mk_proof(R, lhs, path2[i], Hs2[i])); + pr = mk_trans(R_trans, pr, mk_proof(R, lhs, path2[i], Hs2[i], heq_proofs)); lhs = path2[i]; } lean_assert(pr); diff --git a/src/library/blast/congruence_closure.h b/src/library/blast/congruence_closure.h index ce0cc2f9ce..11eee04f83 100644 --- a/src/library/blast/congruence_closure.h +++ b/src/library/blast/congruence_closure.h @@ -146,9 +146,9 @@ class congruence_closure { void add_eqv_core(name const & R, expr const & lhs, expr const & rhs, expr const & H, optional const & added_prop, bool heq_proof); void propagate_no_confusion_eq(expr const & e1, expr const & e2); - expr mk_congr_proof_core(name const & R, expr const & lhs, expr const & rhs) const; - expr mk_congr_proof(name const & R, expr const & lhs, expr const & rhs) const; - expr mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H) const; + expr mk_congr_proof_core(name const & R, expr const & lhs, expr const & rhs, bool heq_proofs) const; + expr mk_congr_proof(name const & R, expr const & lhs, expr const & rhs, bool heq_proofs) const; + expr mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H, bool heq_proofs) const; bool has_heq_proofs(expr const & root) const; @@ -261,10 +261,10 @@ struct ext_congr_lemma { /* If m_fixed_fun is false, then we build equivalences for functions, and use generic congr lemma, and ignore m_congr_lemma. That is, even the function can be treated as an Eq argument. */ unsigned m_fixed_fun:1; - /* If m_uses_heq is true, then lemma is based on heterogeneous equality. */ - unsigned m_heq_based:1; + /* If m_heq_result is true, then lemma is based on heterogeneous equality and the conclusion is a heterogeneous equality. */ + unsigned m_heq_result:1; ext_congr_lemma(congr_lemma const & H); - ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed, bool heq_based); + ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed, bool heq_result); ext_congr_lemma(name const & R, congr_lemma const & H, list> const & rel_names, bool lift_needed); name const & get_relation() const { return m_R; } diff --git a/tests/lean/run/blast_cc_heq3.lean b/tests/lean/run/blast_cc_heq3.lean new file mode 100644 index 0000000000..17569f1db6 --- /dev/null +++ b/tests/lean/run/blast_cc_heq3.lean @@ -0,0 +1,17 @@ +set_option blast.strategy "cc" +set_option blast.cc.heq true -- make sure heterogeneous congruence lemmas are enabled + +axiom vector.{l} : Type.{l} → nat → Type.{l} +axiom app : Π {A : Type} {n m : nat}, vector A m → vector A n → vector A (m+n) + +example (n1 n2 n3 : nat) (v1 w1 : vector nat n1) (w1' : vector nat n3) (v2 w2 : vector nat n2) : + n1 = n3 → v1 = w1 → w1 == w1' → v2 = w2 → app v1 v2 == app w1' w2 := +by blast + +example (n1 n2 n3 : nat) (v1 w1 : vector nat n1) (w1' : vector nat n3) (v2 w2 : vector nat n2) : + n1 == n3 → v1 = w1 → w1 == w1' → v2 == w2 → app v1 v2 == app w1' w2 := +by blast + +example (n1 n2 n3 : nat) (v1 w1 v : vector nat n1) (w1' : vector nat n3) (v2 w2 w : vector nat n2) : + n1 == n3 → v1 = w1 → w1 == w1' → v2 == w2 → app w1' w2 == app v w → app v1 v2 = app v w := +by blast diff --git a/tests/lean/run/blast_cc_heq4.lean b/tests/lean/run/blast_cc_heq4.lean new file mode 100644 index 0000000000..708286a2ba --- /dev/null +++ b/tests/lean/run/blast_cc_heq4.lean @@ -0,0 +1,37 @@ +universes l1 l2 l3 l4 l5 l6 +constants (A : Type.{l1}) (B : A → Type.{l2}) (C : ∀ (a : A) (ba : B a), Type.{l3}) + (D : ∀ (a : A) (ba : B a) (cba : C a ba), Type.{l4}) + (E : ∀ (a : A) (ba : B a) (cba : C a ba) (dcba : D a ba cba), Type.{l5}) + (F : ∀ (a : A) (ba : B a) (cba : C a ba) (dcba : D a ba cba) (edcba : E a ba cba dcba), Type.{l6}) + (C_ss : ∀ a ba, subsingleton (C a ba)) + (a1 a2 a3 : A) + (mk_B1 mk_B2 : ∀ a, B a) + (mk_C1 mk_C2 : ∀ {a} ba, C a ba) + + (tr_B : ∀ {a}, B a → B a) + (x y z : A → A) + + (f f' : ∀ {a : A} {ba : B a} (cba : C a ba), D a ba cba) + (g : ∀ {a : A} {ba : B a} {cba : C a ba} (dcba : D a ba cba), E a ba cba dcba) + (h : ∀ {a : A} {ba : B a} {cba : C a ba} {dcba : D a ba cba} (edcba : E a ba cba dcba), F a ba cba dcba edcba) + +attribute C_ss [instance] + +set_option blast.strategy "cc" +set_option blast.cc.heq true + +example : ∀ {a a' : A}, a == a' → mk_B1 a == mk_B1 a' := +by blast + +example : ∀ {a a' : A}, a == a' → mk_B2 a == mk_B2 a' := +by blast + +example : a1 == y a2 → mk_B1 a1 == mk_B1 (y a2) := +by blast + +example : a1 == x a2 → a2 == y a1 → mk_B1 (x (y a1)) == mk_B1 (x (y (x a2))) := +by blast + +-- The following one needs subsingleton support +-- example : a1 == y a2 → mk_B1 a1 == mk_B2 (y a2) → f (mk_C1 (mk_B2 a1)) == f (mk_C2 (mk_B1 (y a2))) := +-- by blast