From cb02d1deae6fe95af0917f9545cd9e7936fd41a2 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 6 Jan 2016 17:30:20 -0800 Subject: [PATCH] feat(library/blast/congruence_closure): add support for specialized congr lemmas in the congruence closure module --- src/library/blast/blast.cpp | 9 +++ src/library/blast/blast.h | 2 + src/library/blast/congruence_closure.cpp | 81 ++++++++++++++++------ src/library/blast/congruence_closure.h | 4 +- src/library/blast/forward/ematch.cpp | 2 +- tests/lean/run/blast_cc_subsingleton1.lean | 28 ++++++++ tests/lean/run/blast_cc_subsingleton2.lean | 14 ++++ 7 files changed, 114 insertions(+), 26 deletions(-) create mode 100644 tests/lean/run/blast_cc_subsingleton1.lean create mode 100644 tests/lean/run/blast_cc_subsingleton2.lean diff --git a/src/library/blast/blast.cpp b/src/library/blast/blast.cpp index f4475ad620..78a0296e15 100644 --- a/src/library/blast/blast.cpp +++ b/src/library/blast/blast.cpp @@ -755,6 +755,10 @@ public: return m_fun_info_manager.get_specialized(a); } + unsigned get_specialization_prefix_size(expr const & fn, unsigned nargs) { + return m_fun_info_manager.get_specialization_prefix_size(fn, nargs); + } + unsigned abstract_hash(expr const & e) { return m_abstract_expr_manager.hash(e); } @@ -1140,6 +1144,11 @@ fun_info get_specialized_fun_info(expr const & a) { return g_blastenv->get_specialized_fun_info(a); } +unsigned get_specialization_prefix_size(expr const & fn, unsigned nargs) { + lean_assert(g_blastenv); + return g_blastenv->get_specialization_prefix_size(fn, nargs); +} + unsigned abstract_hash(expr const & e) { lean_assert(g_blastenv); return g_blastenv->abstract_hash(e); diff --git a/src/library/blast/blast.h b/src/library/blast/blast.h index 27fa2c1307..bdd64f1b49 100644 --- a/src/library/blast/blast.h +++ b/src/library/blast/blast.h @@ -160,6 +160,8 @@ fun_info get_fun_info(expr const & fn, unsigned nargs); taking into account the actual arguments. \pre is_app(a) */ fun_info get_specialized_fun_info(expr const & a); +/** \brief Return the given function specialization prefix size. */ +unsigned get_specialization_prefix_size(expr const & fn, unsigned nargs); /** \brief Hash and equality test for abstract expressions */ unsigned abstract_hash(expr const & e); diff --git a/src/library/blast/congruence_closure.cpp b/src/library/blast/congruence_closure.cpp index 4da18f492b..ae8f3148f4 100644 --- a/src/library/blast/congruence_closure.cpp +++ b/src/library/blast/congruence_closure.cpp @@ -192,7 +192,6 @@ static optional to_ext_congr_lemma(name const & R, expr const & Rcs.resize(lhs_args.size(), optional()); r_hyps.resize(lhs_args.size(), none_expr()); // Set Fixed args - // TODO(Leo): handle FixedNoParam case? for (unsigned i = 0; i < lhs_args.size(); i++) { if (lhs_args[i] == rhs_args[i]) kinds[i] = congr_arg_kind::Fixed; @@ -243,7 +242,8 @@ static optional to_ext_congr_lemma(name const & R, expr const & } switch (kinds[i]) { case congr_arg_kind::FixedNoParam: - // TODO(Leo): revise this code + // User defined congruence rules do not use FixedNoParam + lean_unreachable(); break; case congr_arg_kind::Fixed: break; @@ -283,7 +283,7 @@ static optional to_ext_congr_lemma(name const & R, expr const & return optional(R, new_lemma, to_list(Rcs), lift_needed); } -static optional mk_ext_congr_lemma_core(name const & R, expr const & fn, unsigned nargs) { +static optional mk_ext_user_congr_lemma(name const & R, expr const & fn, unsigned nargs) { simp_lemmas_for const * sr = get_simp_lemmas().find(R); if (sr) { list const * crs = sr->find_congr(fn); @@ -294,8 +294,11 @@ static optional mk_ext_congr_lemma_core(name const & R, expr co } } } + return optional(); +} - // Automatically generated lemma for equivalence relation over iff/eq +/* Automatically generated lemma for equivalence relation over iff/eq. */ +static optional mk_relation_congr_lemma(name const & R, expr const & fn, unsigned nargs) { if (auto info = is_relation(fn)) { if (info->get_arity() == nargs) { if (R == get_iff_name()) { @@ -311,9 +314,13 @@ static optional mk_ext_congr_lemma_core(name const & R, expr co } } } + return optional(); +} - // Automatically generated lemma - optional eq_congr = mk_congr_lemma(fn, nargs); +/* Automatically generated lemma for function application \c e. The lemma is specialized using the + specialization prefix for \c e. */ +static optional mk_ext_specialized_congr_lemma(name const & R, expr const & e) { + optional eq_congr = mk_specialized_congr_lemma(e); if (!eq_congr) return optional(); ext_congr_lemma res1(*eq_congr); @@ -326,14 +333,45 @@ static optional mk_ext_congr_lemma_core(name const & R, expr co return optional(R, *eq_congr, lift_needed); } -optional mk_ext_congr_lemma(name const & R, expr const & fn, unsigned nargs) { - congr_lemma_key key(R, fn, nargs); - auto it = g_congr_cache->find(key); - if (it != g_congr_cache->end()) - return it->second; - auto r = mk_ext_congr_lemma_core(R, fn, nargs); - g_congr_cache->insert(mk_pair(key, r)); - return r; +optional mk_ext_congr_lemma(name const & R, expr const & e) { + expr const & fn = get_app_fn(e); + unsigned nargs = get_app_num_args(e); + /* Check if (R, fn, nargs) is in the cache */ + congr_lemma_key key1(R, fn, nargs); + auto it1 = g_congr_cache->find(key1); + if (it1 != g_congr_cache->end()) + return it1->second; + /* Check if (g := fn+specialization prefix) is in the cache */ + unsigned prefix_sz = get_specialization_prefix_size(fn, nargs); + unsigned rest_nargs = nargs - prefix_sz; + expr g = e; + for (unsigned i = 0; i < rest_nargs; i++) g = app_fn(g); + congr_lemma_key key2(R, g, rest_nargs); + auto it2 = g_congr_cache->find(key2); + if (it2 != g_congr_cache->end()) + return it2->second; + /* Check if there is user defined lemma for (R, fn, nargs). + Remark: specialization prefix is irrelevan for used defined congruence lemmas. */ + if (auto lemma = mk_ext_user_congr_lemma(R, fn, nargs)) { + g_congr_cache->insert(mk_pair(key1, lemma)); + return lemma; + } + /* Try automatically generated lemma for equivalence relation over iff/eq */ + if (auto lemma = mk_relation_congr_lemma(R, fn, nargs)) { + g_congr_cache->insert(mk_pair(key1, lemma)); + return lemma; + } + /* Try automatically generated specialized congruence lemma */ + if (auto lemma = mk_ext_specialized_congr_lemma(R, e)) { + if (prefix_sz == 0) + g_congr_cache->insert(mk_pair(key1, lemma)); + else + g_congr_cache->insert(mk_pair(key2, lemma)); + return lemma; + } + /* cache failure */ + g_congr_cache->insert(mk_pair(key1, optional())); + return optional(); } void congruence_closure::update_non_eq_relations(name const & R) { @@ -412,7 +450,7 @@ int congruence_closure::congr_key_cmp::operator()(congr_key const & k1, congr_ke expr const & fn2 = get_app_args(k2.m_expr, args2); if (args1.size() != args2.size()) return unsigned_cmp()(args1.size(), args2.size()); - auto lemma = mk_ext_congr_lemma(k1.m_R, fn1, args1.size()); + auto lemma = mk_ext_congr_lemma(k1.m_R, k1.m_expr); lean_assert(lemma); if (!lemma->m_fixed_fun) { int r = g_cc->compare_root(get_eq_name(), fn1, fn2); @@ -552,6 +590,7 @@ void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, exp check_iff_true(k); } +// TODO(Leo): this should not be hard-coded static bool is_logical_app(expr const & n) { if (!is_app(n)) return false; expr const & fn = get_app_fn(n); @@ -614,7 +653,7 @@ void congruence_closure::internalize_core(name const & R, expr const & e, bool t } else { to_propagate = false; } - if (auto lemma = mk_ext_congr_lemma(R, fn, args.size())) { + if (auto lemma = mk_ext_congr_lemma(R, e)) { list> const * it = &(lemma->m_rel_names); for (expr const & arg : args) { lean_assert(*it); @@ -681,9 +720,7 @@ void congruence_closure::remove_parents(name const & R, expr const & e) { auto ps = m_parents.find(child_key(R, e)); if (!ps) return; ps->for_each([&](parent_occ const & p) { - expr const & fn = get_app_fn(p.m_expr); - unsigned nargs = get_app_num_args(p.m_expr); - auto lemma = mk_ext_congr_lemma(p.m_R, fn, nargs); + auto lemma = mk_ext_congr_lemma(p.m_R, p.m_expr); lean_assert(lemma); congr_key k = mk_congr_key(*lemma, p.m_expr); m_congruences.erase(k); @@ -694,9 +731,7 @@ void congruence_closure::reinsert_parents(name const & R, expr const & e) { auto ps = m_parents.find(child_key(R, e)); if (!ps) return; ps->for_each([&](parent_occ const & p) { - expr const & fn = get_app_fn(p.m_expr); - unsigned nargs = get_app_num_args(p.m_expr); - auto lemma = mk_ext_congr_lemma(p.m_R, fn, nargs); + auto lemma = mk_ext_congr_lemma(p.m_R, p.m_expr); lean_assert(lemma); add_congruence_table(*lemma, p.m_expr); }); @@ -955,7 +990,7 @@ expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, e expr const & lhs_fn = get_app_args(lhs, lhs_args); expr const & rhs_fn = get_app_args(rhs, rhs_args); lean_assert(lhs_args.size() == rhs_args.size()); - auto lemma = mk_ext_congr_lemma(R, lhs_fn, lhs_args.size()); + auto lemma = mk_ext_congr_lemma(R, lhs); lean_assert(lemma); if (lemma->m_fixed_fun) { list> const * it1 = &lemma->m_rel_names; diff --git a/src/library/blast/congruence_closure.h b/src/library/blast/congruence_closure.h index 5f3f6b1409..d09e1829e4 100644 --- a/src/library/blast/congruence_closure.h +++ b/src/library/blast/congruence_closure.h @@ -261,9 +261,9 @@ struct ext_congr_lemma { list> const & get_arg_rel_names() const { return m_rel_names; } }; -/** \brief Build an extended congruence lemma for function \c fn with \c nargs expected arguments over relation \c R. +/** \brief Build an extended congruence lemma for function the function application \c e over relation \c R. A subset of user-defined congruence lemmas is considered by this procedure. */ -optional mk_ext_congr_lemma(name const & R, expr const & fn, unsigned nargs); +optional mk_ext_congr_lemma(name const & R, expr const & e); void initialize_congruence_closure(); void finalize_congruence_closure(); diff --git a/src/library/blast/forward/ematch.cpp b/src/library/blast/forward/ematch.cpp index 542cc60bdf..4bc4518581 100644 --- a/src/library/blast/forward/ematch.cpp +++ b/src/library/blast/forward/ematch.cpp @@ -293,7 +293,7 @@ struct ematch_fn { } bool match_args(state & s, name const & R, buffer const & p_args, expr const & t) { - optional cg_lemma = mk_ext_congr_lemma(R, get_app_fn(t), p_args.size()); + optional cg_lemma = mk_ext_congr_lemma(R, t); if (!cg_lemma) return false; buffer t_args; diff --git a/tests/lean/run/blast_cc_subsingleton1.lean b/tests/lean/run/blast_cc_subsingleton1.lean new file mode 100644 index 0000000000..663eedb417 --- /dev/null +++ b/tests/lean/run/blast_cc_subsingleton1.lean @@ -0,0 +1,28 @@ +import data.unit +open nat unit + +constant f {A : Type} (a : A) {B : Type} (b : B) : nat + +constant g : unit → nat + +set_option blast.strategy "cc" + +example (a b : unit) : g a = g b := +by blast + +example (a c : unit) (b d : nat) : b = d → f a b = f c d := +by blast + +constant h {A B : Type} : A → B → nat + +example (a b c d : unit) : h a b = h c d := +by blast + +definition C [reducible] : nat → Type₁ +| nat.zero := unit +| (nat.succ a) := nat + +constant g₂ : Π {n : nat}, C n → nat → nat + +example (a b : C zero) (c d : nat) : c = d → g₂ a c = g₂ b d := +by blast diff --git a/tests/lean/run/blast_cc_subsingleton2.lean b/tests/lean/run/blast_cc_subsingleton2.lean new file mode 100644 index 0000000000..fb4f9bb023 --- /dev/null +++ b/tests/lean/run/blast_cc_subsingleton2.lean @@ -0,0 +1,14 @@ +import data.unit +open nat unit + +set_option blast.strategy "cc" + +constant r {A B : Type} : A → B → A + +definition ex1 (a b c d : unit) : r a b = r c d := +by blast + +-- The congruence closure module does not automatically merge subsingleton equivalence classes. +-- +-- example (a b : unit) : a = b := +-- by blast