feat(library/tactic/smt/congruence_closure): improve propagation for beta reduction in the congruence closure module

This commit is contained in:
Leonardo de Moura 2017-01-24 12:09:37 -08:00
parent 28ce1e6d2b
commit ac6bfce01c
4 changed files with 149 additions and 32 deletions

View file

@ -15,7 +15,7 @@ structure cc_config :=
*only* considered for the functions in fns and local functions. The performance overhead is described in the paper
"Congruence Closure in Intensional Type Theory". If ho_fns is none, then full support is provided
for *all* constants. -/
(ho_fns : option (list name) := some [])
(ho_fns : option (list name) := none)
/- If true, then use excluded middle -/
(em : bool := tt)

View file

@ -504,16 +504,23 @@ void congruence_closure::apply_simple_eqvs(expr const & e) {
push_refl_eq(e, reduced_e);
}
expr root_fn = get_root(fn);
auto it = get_entry(root_fn);
if (it && it->m_has_lambdas) {
buffer<expr> lambdas;
get_eqc_lambdas(root_fn, lambdas);
buffer<expr> new_lambda_apps;
propagate_beta(e, lambdas, new_lambda_apps);
for (expr const & new_app : new_lambda_apps) {
internalize_core(new_app, none_expr());
buffer<expr> rev_args;
auto it = e;
while (is_app(it)) {
rev_args.push_back(app_arg(it));
expr const & fn = app_fn(it);
expr root_fn = get_root(fn);
auto en = get_entry(root_fn);
if (en && en->m_has_lambdas) {
buffer<expr> lambdas;
get_eqc_lambdas(root_fn, lambdas);
buffer<expr> new_lambda_apps;
propagate_beta(fn, rev_args, lambdas, new_lambda_apps);
for (expr const & new_app : new_lambda_apps) {
internalize_core(new_app, none_expr());
}
}
it = fn;
}
propagate_up(e);
@ -1695,19 +1702,66 @@ void congruence_closure::get_eqc_lambdas(expr const & e, buffer<expr> & r) {
} while (it != e);
}
void congruence_closure::propagate_beta(expr const & e, buffer<expr> const & lambdas, buffer<expr> & new_lambda_apps) {
lean_assert(is_app(e));
buffer<expr> args;
expr const & fn = get_app_args(e, args);
void congruence_closure::propagate_beta(expr const & fn, buffer<expr> const & rev_args,
buffer<expr> const & lambdas, buffer<expr> & new_lambda_apps) {
for (expr const & lambda : lambdas) {
lean_assert(is_lambda(lambda));
if (fn != lambda && m_ctx.relaxed_is_def_eq(m_ctx.infer(fn), m_ctx.infer(lambda))) {
expr new_app = mk_app(lambda, args);
expr new_app = mk_rev_app(lambda, rev_args);
new_lambda_apps.push_back(new_app);
}
}
}
/* Traverse the root's equivalence class, and collect the function's equivalence class roots. */
void congruence_closure::collect_fn_roots(expr const & root, buffer<expr> & fn_roots) {
lean_assert(get_root(root) == root);
rb_expr_tree visited;
auto it = root;
do {
expr fn_root = get_root(get_app_fn(it));
if (!visited.contains(fn_root)) {
visited.insert(fn_root);
fn_roots.push_back(fn_root);
}
auto it_n = get_entry(it);
it = it_n->m_next;
} while (it != root);
}
/* For each fn_root in fn_roots traverse its parents, and look for a parent prefix that is
in the same equivalence class of the given lambdas.
\remark All expressions in lambdas are in the same equivalence class */
void congruence_closure::propagate_beta_to_eqc(buffer<expr> const & fn_roots, buffer<expr> const & lambdas,
buffer<expr> & new_lambda_apps) {
if (lambdas.empty()) return;
expr const & lambda_root = get_root(lambdas.back());
lean_assert(std::all_of(lambdas.begin(), lambdas.end(), [&](expr const & l) {
return is_lambda(l) && get_root(l) == lambda_root;
}));
for (expr const & fn_root : fn_roots) {
if (auto ps = m_state.m_parents.find(fn_root)) {
ps->for_each([&](parent_occ const & p_occ) {
expr const & p = p_occ.m_expr;
/* Look for a prefix of p which is in the same equivalence class of lambda_root */
buffer<expr> rev_args;
expr it2 = p;
while (is_app(it2)) {
expr const & fn = app_fn(it2);
rev_args.push_back(app_arg(it2));
if (get_root(fn) == lambda_root) {
/* found it */
propagate_beta(fn, rev_args, lambdas, new_lambda_apps);
break;
}
it2 = app_fn(it2);
}
});
}
}
}
void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, bool heq_proof) {
auto n1 = get_entry(e1);
auto n2 = get_entry(e2);
@ -1785,6 +1839,9 @@ void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, bool heq
buffer<expr> lambdas1, lambdas2;
get_eqc_lambdas(e1_root, lambdas1);
get_eqc_lambdas(e2_root, lambdas2);
buffer<expr> fn_roots1, fn_roots2;
if (!lambdas1.empty()) collect_fn_roots(e2_root, fn_roots2);
if (!lambdas2.empty()) collect_fn_roots(e1_root, fn_roots1);
/* force all m_root fields in e1 equivalence class to point to e2_root */
bool propagate = is_true_or_false(e2_root);
@ -1826,17 +1883,8 @@ void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, bool heq
lean_assert(check_invariant());
buffer<expr> lambda_apps_to_internalize;
if (!lambdas1.empty()) {
// beta with e2_root parents and lambdas1 (of e1 class)
if (auto ps2 = m_state.m_parents.find(e2_root)) {
ps2->for_each([&](parent_occ const & p) {
if (is_app(p.m_expr) && get_root(get_app_fn(p.m_expr)) == e2) {
propagate_beta(p.m_expr, lambdas1, lambda_apps_to_internalize);
}
});
}
}
propagate_beta_to_eqc(fn_roots2, lambdas1, lambda_apps_to_internalize);
propagate_beta_to_eqc(fn_roots1, lambdas2, lambda_apps_to_internalize);
// copy e1_root parents to e2_root
auto ps1 = m_state.m_parents.find(e1_root);
@ -1851,10 +1899,6 @@ void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, bool heq
}
ps2.insert(p);
}
if (!lambdas2.empty() && is_app(p.m_expr) && get_root(get_app_fn(p.m_expr)) == e2) {
// beta with e1_root parents and lambdas2 (of e2 class)
propagate_beta(p.m_expr, lambdas2, lambda_apps_to_internalize);
}
});
m_state.m_parents.erase(e1_root);
m_state.m_parents.insert(e2_root, ps2);
@ -1899,7 +1943,8 @@ void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, bool heq
auto fmt = out.get_formatter();
out << "merged: " << e1_root << " = " << e2_root << "\n";
out << m_state.pp_eqcs(fmt) << "\n";
// out << m_state.pp_parent_occs(fmt) << "\n";
if (is_trace_class_enabled(name{"debug", "cc", "parent_occs"}))
out << m_state.pp_parent_occs(fmt) << "\n";
out << "--------\n";);
}
@ -2102,6 +2147,8 @@ void initialize_congruence_closure() {
register_trace_class({"cc", "failure"});
register_trace_class({"cc", "merge"});
register_trace_class({"debug", "cc"});
register_trace_class({"debug", "cc", "parent_occs"});
name prefix = name::mk_internal_unique_name();
g_congr_mark = new expr(mk_constant(name(prefix, "[congruence]")));
g_eq_true_mark = new expr(mk_constant(name(prefix, "[iff-true]")));

View file

@ -259,7 +259,9 @@ private:
void propagate_projection_constructor(expr const & p, expr const & c);
void propagate_value_inconsistency(expr const & e1, expr const & e2);
void get_eqc_lambdas(expr const & e, buffer<expr> & r);
void propagate_beta(expr const & e, buffer<expr> const & lambdas, buffer<expr> & r);
void propagate_beta(expr const & fn, buffer<expr> const & rev_args, buffer<expr> const & lambdas, buffer<expr> & r);
void propagate_beta_to_eqc(buffer<expr> const & fn_roots, buffer<expr> const & lambdas, buffer<expr> & new_lambda_apps);
void collect_fn_roots(expr const & root, buffer<expr> & fn_roots);
void add_eqv_step(expr e1, expr e2, expr const & H, bool heq_proof);
void process_todo();
void add_eqv_core(expr const & lhs, expr const & rhs, expr const & H, bool heq_proof);

View file

@ -0,0 +1,68 @@
example (f : nat → nat → nat) (a b c : nat) :
f a = (λ x, x) → f a b = b :=
begin [smt]
intros,
end
example (f g : nat → nat → nat) (a b c : nat) :
f a = g c → f a = (λ x, x) → g c b = b :=
begin [smt]
intros,
end
constant F : nat → nat → nat
constant G : nat → nat → nat
example (a b c : nat) :
F a = (λ x, x) → F a b = b :=
begin [smt]
intros,
end
example (a b c : nat) :
F a = G c → F a = (λ x, x) → G c b = b :=
begin [smt]
intros,
end
example (f : nat → nat → nat) (a b c : nat) :
f a b ≠ b → f a = (λ x, x) → false :=
begin [smt]
intros,
end
example (f g : nat → nat → nat) (a b c : nat) :
g c b ≠ b → f a = g c → f a = (λ x, x) → false :=
begin [smt]
intros,
end
example (f g : nat → nat → nat) (a b c : nat) :
f a = g c → g c b ≠ b → f a = (λ x, x) → false :=
begin [smt]
intros,
end
example (a b c : nat) :
F a b ≠ b → F a = (λ x, x) → false :=
begin [smt]
intros,
end
example (a b c : nat) :
G c b ≠ b → F a = G c → F a = (λ x, x) → false :=
begin [smt]
intros,
end
example (f : nat → nat → nat) (g : nat → nat → nat → nat) (a b c d : nat) :
g c d b ≠ b → f a = g c a → f a = (λ x, x) → d = a → false :=
begin [smt]
intros,
end
example (f : nat → nat → nat) (g : nat → nat → nat → nat) (a b c d : nat) :
d = a → g c d b ≠ b → f a = g c a → f a = (λ x, x) → false :=
begin [smt]
intros,
end