feat(library/tactic/smt/congruence_closure): improve propagation for beta reduction in the congruence closure module
This commit is contained in:
parent
28ce1e6d2b
commit
ac6bfce01c
4 changed files with 149 additions and 32 deletions
|
|
@ -15,7 +15,7 @@ structure cc_config :=
|
|||
*only* considered for the functions in fns and local functions. The performance overhead is described in the paper
|
||||
"Congruence Closure in Intensional Type Theory". If ho_fns is none, then full support is provided
|
||||
for *all* constants. -/
|
||||
(ho_fns : option (list name) := some [])
|
||||
(ho_fns : option (list name) := none)
|
||||
/- If true, then use excluded middle -/
|
||||
(em : bool := tt)
|
||||
|
||||
|
|
|
|||
|
|
@ -504,16 +504,23 @@ void congruence_closure::apply_simple_eqvs(expr const & e) {
|
|||
push_refl_eq(e, reduced_e);
|
||||
}
|
||||
|
||||
expr root_fn = get_root(fn);
|
||||
auto it = get_entry(root_fn);
|
||||
if (it && it->m_has_lambdas) {
|
||||
buffer<expr> lambdas;
|
||||
get_eqc_lambdas(root_fn, lambdas);
|
||||
buffer<expr> new_lambda_apps;
|
||||
propagate_beta(e, lambdas, new_lambda_apps);
|
||||
for (expr const & new_app : new_lambda_apps) {
|
||||
internalize_core(new_app, none_expr());
|
||||
buffer<expr> rev_args;
|
||||
auto it = e;
|
||||
while (is_app(it)) {
|
||||
rev_args.push_back(app_arg(it));
|
||||
expr const & fn = app_fn(it);
|
||||
expr root_fn = get_root(fn);
|
||||
auto en = get_entry(root_fn);
|
||||
if (en && en->m_has_lambdas) {
|
||||
buffer<expr> lambdas;
|
||||
get_eqc_lambdas(root_fn, lambdas);
|
||||
buffer<expr> new_lambda_apps;
|
||||
propagate_beta(fn, rev_args, lambdas, new_lambda_apps);
|
||||
for (expr const & new_app : new_lambda_apps) {
|
||||
internalize_core(new_app, none_expr());
|
||||
}
|
||||
}
|
||||
it = fn;
|
||||
}
|
||||
|
||||
propagate_up(e);
|
||||
|
|
@ -1695,19 +1702,66 @@ void congruence_closure::get_eqc_lambdas(expr const & e, buffer<expr> & r) {
|
|||
} while (it != e);
|
||||
}
|
||||
|
||||
void congruence_closure::propagate_beta(expr const & e, buffer<expr> const & lambdas, buffer<expr> & new_lambda_apps) {
|
||||
lean_assert(is_app(e));
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(e, args);
|
||||
void congruence_closure::propagate_beta(expr const & fn, buffer<expr> const & rev_args,
|
||||
buffer<expr> const & lambdas, buffer<expr> & new_lambda_apps) {
|
||||
for (expr const & lambda : lambdas) {
|
||||
lean_assert(is_lambda(lambda));
|
||||
if (fn != lambda && m_ctx.relaxed_is_def_eq(m_ctx.infer(fn), m_ctx.infer(lambda))) {
|
||||
expr new_app = mk_app(lambda, args);
|
||||
expr new_app = mk_rev_app(lambda, rev_args);
|
||||
new_lambda_apps.push_back(new_app);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Traverse the root's equivalence class, and collect the function's equivalence class roots. */
|
||||
void congruence_closure::collect_fn_roots(expr const & root, buffer<expr> & fn_roots) {
|
||||
lean_assert(get_root(root) == root);
|
||||
rb_expr_tree visited;
|
||||
auto it = root;
|
||||
do {
|
||||
expr fn_root = get_root(get_app_fn(it));
|
||||
if (!visited.contains(fn_root)) {
|
||||
visited.insert(fn_root);
|
||||
fn_roots.push_back(fn_root);
|
||||
}
|
||||
auto it_n = get_entry(it);
|
||||
it = it_n->m_next;
|
||||
} while (it != root);
|
||||
}
|
||||
|
||||
/* For each fn_root in fn_roots traverse its parents, and look for a parent prefix that is
|
||||
in the same equivalence class of the given lambdas.
|
||||
|
||||
\remark All expressions in lambdas are in the same equivalence class */
|
||||
void congruence_closure::propagate_beta_to_eqc(buffer<expr> const & fn_roots, buffer<expr> const & lambdas,
|
||||
buffer<expr> & new_lambda_apps) {
|
||||
if (lambdas.empty()) return;
|
||||
expr const & lambda_root = get_root(lambdas.back());
|
||||
lean_assert(std::all_of(lambdas.begin(), lambdas.end(), [&](expr const & l) {
|
||||
return is_lambda(l) && get_root(l) == lambda_root;
|
||||
}));
|
||||
for (expr const & fn_root : fn_roots) {
|
||||
if (auto ps = m_state.m_parents.find(fn_root)) {
|
||||
ps->for_each([&](parent_occ const & p_occ) {
|
||||
expr const & p = p_occ.m_expr;
|
||||
/* Look for a prefix of p which is in the same equivalence class of lambda_root */
|
||||
buffer<expr> rev_args;
|
||||
expr it2 = p;
|
||||
while (is_app(it2)) {
|
||||
expr const & fn = app_fn(it2);
|
||||
rev_args.push_back(app_arg(it2));
|
||||
if (get_root(fn) == lambda_root) {
|
||||
/* found it */
|
||||
propagate_beta(fn, rev_args, lambdas, new_lambda_apps);
|
||||
break;
|
||||
}
|
||||
it2 = app_fn(it2);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, bool heq_proof) {
|
||||
auto n1 = get_entry(e1);
|
||||
auto n2 = get_entry(e2);
|
||||
|
|
@ -1785,6 +1839,9 @@ void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, bool heq
|
|||
buffer<expr> lambdas1, lambdas2;
|
||||
get_eqc_lambdas(e1_root, lambdas1);
|
||||
get_eqc_lambdas(e2_root, lambdas2);
|
||||
buffer<expr> fn_roots1, fn_roots2;
|
||||
if (!lambdas1.empty()) collect_fn_roots(e2_root, fn_roots2);
|
||||
if (!lambdas2.empty()) collect_fn_roots(e1_root, fn_roots1);
|
||||
|
||||
/* force all m_root fields in e1 equivalence class to point to e2_root */
|
||||
bool propagate = is_true_or_false(e2_root);
|
||||
|
|
@ -1826,17 +1883,8 @@ void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, bool heq
|
|||
lean_assert(check_invariant());
|
||||
|
||||
buffer<expr> lambda_apps_to_internalize;
|
||||
|
||||
if (!lambdas1.empty()) {
|
||||
// beta with e2_root parents and lambdas1 (of e1 class)
|
||||
if (auto ps2 = m_state.m_parents.find(e2_root)) {
|
||||
ps2->for_each([&](parent_occ const & p) {
|
||||
if (is_app(p.m_expr) && get_root(get_app_fn(p.m_expr)) == e2) {
|
||||
propagate_beta(p.m_expr, lambdas1, lambda_apps_to_internalize);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
propagate_beta_to_eqc(fn_roots2, lambdas1, lambda_apps_to_internalize);
|
||||
propagate_beta_to_eqc(fn_roots1, lambdas2, lambda_apps_to_internalize);
|
||||
|
||||
// copy e1_root parents to e2_root
|
||||
auto ps1 = m_state.m_parents.find(e1_root);
|
||||
|
|
@ -1851,10 +1899,6 @@ void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, bool heq
|
|||
}
|
||||
ps2.insert(p);
|
||||
}
|
||||
if (!lambdas2.empty() && is_app(p.m_expr) && get_root(get_app_fn(p.m_expr)) == e2) {
|
||||
// beta with e1_root parents and lambdas2 (of e2 class)
|
||||
propagate_beta(p.m_expr, lambdas2, lambda_apps_to_internalize);
|
||||
}
|
||||
});
|
||||
m_state.m_parents.erase(e1_root);
|
||||
m_state.m_parents.insert(e2_root, ps2);
|
||||
|
|
@ -1899,7 +1943,8 @@ void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, bool heq
|
|||
auto fmt = out.get_formatter();
|
||||
out << "merged: " << e1_root << " = " << e2_root << "\n";
|
||||
out << m_state.pp_eqcs(fmt) << "\n";
|
||||
// out << m_state.pp_parent_occs(fmt) << "\n";
|
||||
if (is_trace_class_enabled(name{"debug", "cc", "parent_occs"}))
|
||||
out << m_state.pp_parent_occs(fmt) << "\n";
|
||||
out << "--------\n";);
|
||||
}
|
||||
|
||||
|
|
@ -2102,6 +2147,8 @@ void initialize_congruence_closure() {
|
|||
register_trace_class({"cc", "failure"});
|
||||
register_trace_class({"cc", "merge"});
|
||||
register_trace_class({"debug", "cc"});
|
||||
register_trace_class({"debug", "cc", "parent_occs"});
|
||||
|
||||
name prefix = name::mk_internal_unique_name();
|
||||
g_congr_mark = new expr(mk_constant(name(prefix, "[congruence]")));
|
||||
g_eq_true_mark = new expr(mk_constant(name(prefix, "[iff-true]")));
|
||||
|
|
|
|||
|
|
@ -259,7 +259,9 @@ private:
|
|||
void propagate_projection_constructor(expr const & p, expr const & c);
|
||||
void propagate_value_inconsistency(expr const & e1, expr const & e2);
|
||||
void get_eqc_lambdas(expr const & e, buffer<expr> & r);
|
||||
void propagate_beta(expr const & e, buffer<expr> const & lambdas, buffer<expr> & r);
|
||||
void propagate_beta(expr const & fn, buffer<expr> const & rev_args, buffer<expr> const & lambdas, buffer<expr> & r);
|
||||
void propagate_beta_to_eqc(buffer<expr> const & fn_roots, buffer<expr> const & lambdas, buffer<expr> & new_lambda_apps);
|
||||
void collect_fn_roots(expr const & root, buffer<expr> & fn_roots);
|
||||
void add_eqv_step(expr e1, expr e2, expr const & H, bool heq_proof);
|
||||
void process_todo();
|
||||
void add_eqv_core(expr const & lhs, expr const & rhs, expr const & H, bool heq_proof);
|
||||
|
|
|
|||
68
tests/lean/run/cc_beta.lean
Normal file
68
tests/lean/run/cc_beta.lean
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
example (f : nat → nat → nat) (a b c : nat) :
|
||||
f a = (λ x, x) → f a b = b :=
|
||||
begin [smt]
|
||||
intros,
|
||||
end
|
||||
|
||||
example (f g : nat → nat → nat) (a b c : nat) :
|
||||
f a = g c → f a = (λ x, x) → g c b = b :=
|
||||
begin [smt]
|
||||
intros,
|
||||
end
|
||||
|
||||
constant F : nat → nat → nat
|
||||
constant G : nat → nat → nat
|
||||
|
||||
example (a b c : nat) :
|
||||
F a = (λ x, x) → F a b = b :=
|
||||
begin [smt]
|
||||
intros,
|
||||
end
|
||||
|
||||
example (a b c : nat) :
|
||||
F a = G c → F a = (λ x, x) → G c b = b :=
|
||||
begin [smt]
|
||||
intros,
|
||||
end
|
||||
|
||||
example (f : nat → nat → nat) (a b c : nat) :
|
||||
f a b ≠ b → f a = (λ x, x) → false :=
|
||||
begin [smt]
|
||||
intros,
|
||||
end
|
||||
|
||||
example (f g : nat → nat → nat) (a b c : nat) :
|
||||
g c b ≠ b → f a = g c → f a = (λ x, x) → false :=
|
||||
begin [smt]
|
||||
intros,
|
||||
end
|
||||
|
||||
example (f g : nat → nat → nat) (a b c : nat) :
|
||||
f a = g c → g c b ≠ b → f a = (λ x, x) → false :=
|
||||
begin [smt]
|
||||
intros,
|
||||
end
|
||||
|
||||
example (a b c : nat) :
|
||||
F a b ≠ b → F a = (λ x, x) → false :=
|
||||
begin [smt]
|
||||
intros,
|
||||
end
|
||||
|
||||
example (a b c : nat) :
|
||||
G c b ≠ b → F a = G c → F a = (λ x, x) → false :=
|
||||
begin [smt]
|
||||
intros,
|
||||
end
|
||||
|
||||
example (f : nat → nat → nat) (g : nat → nat → nat → nat) (a b c d : nat) :
|
||||
g c d b ≠ b → f a = g c a → f a = (λ x, x) → d = a → false :=
|
||||
begin [smt]
|
||||
intros,
|
||||
end
|
||||
|
||||
example (f : nat → nat → nat) (g : nat → nat → nat → nat) (a b c d : nat) :
|
||||
d = a → g c d b ≠ b → f a = g c a → f a = (λ x, x) → false :=
|
||||
begin [smt]
|
||||
intros,
|
||||
end
|
||||
Loading…
Add table
Reference in a new issue