From ac6bfce01cf1ae4a4a50d55dd4ecfbb22cb9cb6e Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 24 Jan 2017 12:09:37 -0800 Subject: [PATCH] feat(library/tactic/smt/congruence_closure): improve propagation for beta reduction in the congruence closure module --- library/init/meta/smt/congruence_closure.lean | 2 +- src/library/tactic/smt/congruence_closure.cpp | 107 +++++++++++++----- src/library/tactic/smt/congruence_closure.h | 4 +- tests/lean/run/cc_beta.lean | 68 +++++++++++ 4 files changed, 149 insertions(+), 32 deletions(-) create mode 100644 tests/lean/run/cc_beta.lean diff --git a/library/init/meta/smt/congruence_closure.lean b/library/init/meta/smt/congruence_closure.lean index 63d4299ceb..bfe07135a1 100644 --- a/library/init/meta/smt/congruence_closure.lean +++ b/library/init/meta/smt/congruence_closure.lean @@ -15,7 +15,7 @@ structure cc_config := *only* considered for the functions in fns and local functions. The performance overhead is described in the paper "Congruence Closure in Intensional Type Theory". If ho_fns is none, then full support is provided for *all* constants. -/ -(ho_fns : option (list name) := some []) +(ho_fns : option (list name) := none) /- If true, then use excluded middle -/ (em : bool := tt) diff --git a/src/library/tactic/smt/congruence_closure.cpp b/src/library/tactic/smt/congruence_closure.cpp index 93127949ab..eba0afa376 100644 --- a/src/library/tactic/smt/congruence_closure.cpp +++ b/src/library/tactic/smt/congruence_closure.cpp @@ -504,16 +504,23 @@ void congruence_closure::apply_simple_eqvs(expr const & e) { push_refl_eq(e, reduced_e); } - expr root_fn = get_root(fn); - auto it = get_entry(root_fn); - if (it && it->m_has_lambdas) { - buffer lambdas; - get_eqc_lambdas(root_fn, lambdas); - buffer new_lambda_apps; - propagate_beta(e, lambdas, new_lambda_apps); - for (expr const & new_app : new_lambda_apps) { - internalize_core(new_app, none_expr()); + buffer rev_args; + auto it = e; + while (is_app(it)) { + rev_args.push_back(app_arg(it)); + expr const & fn = app_fn(it); + expr root_fn = get_root(fn); + auto en = get_entry(root_fn); + if (en && en->m_has_lambdas) { + buffer lambdas; + get_eqc_lambdas(root_fn, lambdas); + buffer new_lambda_apps; + propagate_beta(fn, rev_args, lambdas, new_lambda_apps); + for (expr const & new_app : new_lambda_apps) { + internalize_core(new_app, none_expr()); + } } + it = fn; } propagate_up(e); @@ -1695,19 +1702,66 @@ void congruence_closure::get_eqc_lambdas(expr const & e, buffer & r) { } while (it != e); } -void congruence_closure::propagate_beta(expr const & e, buffer const & lambdas, buffer & new_lambda_apps) { - lean_assert(is_app(e)); - buffer args; - expr const & fn = get_app_args(e, args); +void congruence_closure::propagate_beta(expr const & fn, buffer const & rev_args, + buffer const & lambdas, buffer & new_lambda_apps) { for (expr const & lambda : lambdas) { lean_assert(is_lambda(lambda)); if (fn != lambda && m_ctx.relaxed_is_def_eq(m_ctx.infer(fn), m_ctx.infer(lambda))) { - expr new_app = mk_app(lambda, args); + expr new_app = mk_rev_app(lambda, rev_args); new_lambda_apps.push_back(new_app); } } } +/* Traverse the root's equivalence class, and collect the function's equivalence class roots. */ +void congruence_closure::collect_fn_roots(expr const & root, buffer & fn_roots) { + lean_assert(get_root(root) == root); + rb_expr_tree visited; + auto it = root; + do { + expr fn_root = get_root(get_app_fn(it)); + if (!visited.contains(fn_root)) { + visited.insert(fn_root); + fn_roots.push_back(fn_root); + } + auto it_n = get_entry(it); + it = it_n->m_next; + } while (it != root); +} + +/* For each fn_root in fn_roots traverse its parents, and look for a parent prefix that is + in the same equivalence class of the given lambdas. + + \remark All expressions in lambdas are in the same equivalence class */ +void congruence_closure::propagate_beta_to_eqc(buffer const & fn_roots, buffer const & lambdas, + buffer & new_lambda_apps) { + if (lambdas.empty()) return; + expr const & lambda_root = get_root(lambdas.back()); + lean_assert(std::all_of(lambdas.begin(), lambdas.end(), [&](expr const & l) { + return is_lambda(l) && get_root(l) == lambda_root; + })); + for (expr const & fn_root : fn_roots) { + if (auto ps = m_state.m_parents.find(fn_root)) { + ps->for_each([&](parent_occ const & p_occ) { + expr const & p = p_occ.m_expr; + /* Look for a prefix of p which is in the same equivalence class of lambda_root */ + buffer rev_args; + expr it2 = p; + while (is_app(it2)) { + expr const & fn = app_fn(it2); + rev_args.push_back(app_arg(it2)); + if (get_root(fn) == lambda_root) { + /* found it */ + propagate_beta(fn, rev_args, lambdas, new_lambda_apps); + break; + } + it2 = app_fn(it2); + } + }); + } + } +} + void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, bool heq_proof) { auto n1 = get_entry(e1); auto n2 = get_entry(e2); @@ -1785,6 +1839,9 @@ void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, bool heq buffer lambdas1, lambdas2; get_eqc_lambdas(e1_root, lambdas1); get_eqc_lambdas(e2_root, lambdas2); + buffer fn_roots1, fn_roots2; + if (!lambdas1.empty()) collect_fn_roots(e2_root, fn_roots2); + if (!lambdas2.empty()) collect_fn_roots(e1_root, fn_roots1); /* force all m_root fields in e1 equivalence class to point to e2_root */ bool propagate = is_true_or_false(e2_root); @@ -1826,17 +1883,8 @@ void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, bool heq lean_assert(check_invariant()); buffer lambda_apps_to_internalize; - - if (!lambdas1.empty()) { - // beta with e2_root parents and lambdas1 (of e1 class) - if (auto ps2 = m_state.m_parents.find(e2_root)) { - ps2->for_each([&](parent_occ const & p) { - if (is_app(p.m_expr) && get_root(get_app_fn(p.m_expr)) == e2) { - propagate_beta(p.m_expr, lambdas1, lambda_apps_to_internalize); - } - }); - } - } + propagate_beta_to_eqc(fn_roots2, lambdas1, lambda_apps_to_internalize); + propagate_beta_to_eqc(fn_roots1, lambdas2, lambda_apps_to_internalize); // copy e1_root parents to e2_root auto ps1 = m_state.m_parents.find(e1_root); @@ -1851,10 +1899,6 @@ void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, bool heq } ps2.insert(p); } - if (!lambdas2.empty() && is_app(p.m_expr) && get_root(get_app_fn(p.m_expr)) == e2) { - // beta with e1_root parents and lambdas2 (of e2 class) - propagate_beta(p.m_expr, lambdas2, lambda_apps_to_internalize); - } }); m_state.m_parents.erase(e1_root); m_state.m_parents.insert(e2_root, ps2); @@ -1899,7 +1943,8 @@ void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, bool heq auto fmt = out.get_formatter(); out << "merged: " << e1_root << " = " << e2_root << "\n"; out << m_state.pp_eqcs(fmt) << "\n"; - // out << m_state.pp_parent_occs(fmt) << "\n"; + if (is_trace_class_enabled(name{"debug", "cc", "parent_occs"})) + out << m_state.pp_parent_occs(fmt) << "\n"; out << "--------\n";); } @@ -2102,6 +2147,8 @@ void initialize_congruence_closure() { register_trace_class({"cc", "failure"}); register_trace_class({"cc", "merge"}); register_trace_class({"debug", "cc"}); + register_trace_class({"debug", "cc", "parent_occs"}); + name prefix = name::mk_internal_unique_name(); g_congr_mark = new expr(mk_constant(name(prefix, "[congruence]"))); g_eq_true_mark = new expr(mk_constant(name(prefix, "[iff-true]"))); diff --git a/src/library/tactic/smt/congruence_closure.h b/src/library/tactic/smt/congruence_closure.h index 9bfc12c057..91490bf98f 100644 --- a/src/library/tactic/smt/congruence_closure.h +++ b/src/library/tactic/smt/congruence_closure.h @@ -259,7 +259,9 @@ private: void propagate_projection_constructor(expr const & p, expr const & c); void propagate_value_inconsistency(expr const & e1, expr const & e2); void get_eqc_lambdas(expr const & e, buffer & r); - void propagate_beta(expr const & e, buffer const & lambdas, buffer & r); + void propagate_beta(expr const & fn, buffer const & rev_args, buffer const & lambdas, buffer & r); + void propagate_beta_to_eqc(buffer const & fn_roots, buffer const & lambdas, buffer & new_lambda_apps); + void collect_fn_roots(expr const & root, buffer & fn_roots); void add_eqv_step(expr e1, expr e2, expr const & H, bool heq_proof); void process_todo(); void add_eqv_core(expr const & lhs, expr const & rhs, expr const & H, bool heq_proof); diff --git a/tests/lean/run/cc_beta.lean b/tests/lean/run/cc_beta.lean new file mode 100644 index 0000000000..555838e5a9 --- /dev/null +++ b/tests/lean/run/cc_beta.lean @@ -0,0 +1,68 @@ +example (f : nat → nat → nat) (a b c : nat) : + f a = (λ x, x) → f a b = b := +begin [smt] + intros, +end + +example (f g : nat → nat → nat) (a b c : nat) : + f a = g c → f a = (λ x, x) → g c b = b := +begin [smt] + intros, +end + +constant F : nat → nat → nat +constant G : nat → nat → nat + +example (a b c : nat) : + F a = (λ x, x) → F a b = b := +begin [smt] + intros, +end + +example (a b c : nat) : + F a = G c → F a = (λ x, x) → G c b = b := +begin [smt] + intros, +end + +example (f : nat → nat → nat) (a b c : nat) : + f a b ≠ b → f a = (λ x, x) → false := +begin [smt] + intros, +end + +example (f g : nat → nat → nat) (a b c : nat) : + g c b ≠ b → f a = g c → f a = (λ x, x) → false := +begin [smt] + intros, +end + +example (f g : nat → nat → nat) (a b c : nat) : + f a = g c → g c b ≠ b → f a = (λ x, x) → false := +begin [smt] + intros, +end + +example (a b c : nat) : + F a b ≠ b → F a = (λ x, x) → false := +begin [smt] + intros, +end + +example (a b c : nat) : + G c b ≠ b → F a = G c → F a = (λ x, x) → false := +begin [smt] + intros, +end + +example (f : nat → nat → nat) (g : nat → nat → nat → nat) (a b c d : nat) : + g c d b ≠ b → f a = g c a → f a = (λ x, x) → d = a → false := +begin [smt] + intros, +end + +example (f : nat → nat → nat) (g : nat → nat → nat → nat) (a b c d : nat) : + d = a → g c d b ≠ b → f a = g c a → f a = (λ x, x) → false := +begin [smt] + intros, +end