feat(library/blast/congruence_closure): more general congruence lemmas
This commit is contained in:
parent
d9294fc164
commit
3f7122ce07
3 changed files with 348 additions and 192 deletions
|
|
@ -613,6 +613,70 @@ int congruence_closure::compare_root(name const & R, expr e1, expr e2) const {
|
|||
return expr_quick_cmp()(e1, e2);
|
||||
}
|
||||
|
||||
int congruence_closure::eq_congr_key_cmp::operator()(eq_congr_key const & k1, eq_congr_key const & k2) const {
|
||||
lean_assert(g_heq_based);
|
||||
if (k1.m_hash != k2.m_hash)
|
||||
return unsigned_cmp()(k1.m_hash, k2.m_hash);
|
||||
lean_assert(is_app(k1.m_expr) && is_app(k2.m_expr));
|
||||
expr const & f1 = app_fn(k1.m_expr);
|
||||
expr const & a1 = app_arg(k1.m_expr);
|
||||
expr const & f2 = app_fn(k2.m_expr);
|
||||
expr const & a2 = app_arg(k2.m_expr);
|
||||
int r = g_cc->compare_root(get_eq_name(), a1, a2);
|
||||
if (r != 0) return r;
|
||||
r = g_cc->compare_root(get_eq_name(), f1, f2);
|
||||
if (r != 0) return r;
|
||||
if (is_app(f1) && is_app(f2)) return 0; /* composite */
|
||||
if (f1 == f2) return 0; /* identical functions */
|
||||
expr f1_type = infer_type(f1);
|
||||
expr f2_type = infer_type(f2);
|
||||
if (is_def_eq(f1_type, f2_type)) return 0; /* same function type */
|
||||
/*
|
||||
f1 and f2 are not applications and have different types.
|
||||
We can't generate a congruence proof in this case because the following lemma
|
||||
|
||||
hcongr : f1 == f2 -> a1 == a2 -> f1 a1 == f2 a2
|
||||
|
||||
is not provable. Remark: it is also not provable in MLTT, Coq and Agda (even if we assume UIP).
|
||||
*/
|
||||
return expr_quick_cmp()(f1, f2);
|
||||
}
|
||||
|
||||
/* \brief Create a equality congruence table key.
|
||||
\remark This table and key are only used when heterogeneous equality support is enabled. */
|
||||
auto congruence_closure::mk_eq_congr_key(expr const & e) const -> eq_congr_key {
|
||||
lean_assert(is_app(e));
|
||||
eq_congr_key k;
|
||||
k.m_expr = e;
|
||||
expr const & f = app_fn(e);
|
||||
expr const & a = app_arg(e);
|
||||
unsigned h = hash(get_root(get_eq_name(), f).hash(), get_root(get_eq_name(), a).hash());
|
||||
k.m_hash = h;
|
||||
return k;
|
||||
}
|
||||
|
||||
int congruence_closure::cmp_eq_iff_keys(congr_key const & k1, congr_key const & k2) const {
|
||||
lean_assert(k1.m_eq == k2.m_eq);
|
||||
lean_assert(k1.m_iff == k2.m_iff);
|
||||
lean_assert(k1.m_eq != k1.m_iff);
|
||||
name const & R = k1.m_eq ? get_eq_name() : get_iff_name();
|
||||
expr const & lhs1 = app_arg(app_fn(k1.m_expr));
|
||||
expr const & rhs1 = app_arg(k1.m_expr);
|
||||
expr const & lhs2 = app_arg(app_fn(k2.m_expr));
|
||||
expr const & rhs2 = app_arg(k2.m_expr);
|
||||
return compare_symm(R, lhs1, rhs1, lhs2, rhs2);
|
||||
}
|
||||
|
||||
int congruence_closure::cmp_symm_rel_keys(congr_key const & k1, congr_key const & k2) const {
|
||||
name R1, R2;
|
||||
expr lhs1, rhs1, lhs2, rhs2;
|
||||
lean_verify(is_equivalence_relation_app(k1.m_expr, R1, lhs1, rhs1));
|
||||
lean_verify(is_equivalence_relation_app(k2.m_expr, R2, lhs2, rhs2));
|
||||
if (R1 != R2)
|
||||
return quick_cmp(R1, R2);
|
||||
return compare_symm(R1, lhs1, rhs1, lhs2, rhs2);
|
||||
}
|
||||
|
||||
int congruence_closure::congr_key_cmp::operator()(congr_key const & k1, congr_key const & k2) const {
|
||||
if (k1.m_hash != k2.m_hash)
|
||||
return unsigned_cmp()(k1.m_hash, k2.m_hash);
|
||||
|
|
@ -624,67 +688,55 @@ int congruence_closure::congr_key_cmp::operator()(congr_key const & k1, congr_ke
|
|||
return k1.m_iff ? -1 : 1;
|
||||
if (k2.m_symm_rel != k2.m_symm_rel)
|
||||
return k1.m_symm_rel ? -1 : 1;
|
||||
if (k1.m_eq || k1.m_iff) {
|
||||
name const & R = k1.m_eq ? get_eq_name() : get_iff_name();
|
||||
expr const & lhs1 = app_arg(app_fn(k1.m_expr));
|
||||
expr const & rhs1 = app_arg(k1.m_expr);
|
||||
expr const & lhs2 = app_arg(app_fn(k2.m_expr));
|
||||
expr const & rhs2 = app_arg(k2.m_expr);
|
||||
return g_cc->compare_symm(R, lhs1, rhs1, lhs2, rhs2);
|
||||
} else if (k1.m_symm_rel) {
|
||||
name R1, R2;
|
||||
expr lhs1, rhs1, lhs2, rhs2;
|
||||
lean_verify(is_equivalence_relation_app(k1.m_expr, R1, lhs1, rhs1));
|
||||
lean_verify(is_equivalence_relation_app(k2.m_expr, R2, lhs2, rhs2));
|
||||
if (R1 != R2)
|
||||
return quick_cmp(R1, R2);
|
||||
return g_cc->compare_symm(R1, lhs1, rhs1, lhs2, rhs2);
|
||||
} else {
|
||||
lean_assert(!k1.m_eq && !k2.m_eq && !k1.m_iff && !k2.m_iff &&
|
||||
!k1.m_symm_rel && !k2.m_symm_rel);
|
||||
lean_assert(k1.m_R == k2.m_R);
|
||||
buffer<expr> args1, args2;
|
||||
expr const & fn1 = get_app_args(k1.m_expr, args1);
|
||||
expr const & fn2 = get_app_args(k2.m_expr, args2);
|
||||
if (args1.size() != args2.size())
|
||||
return unsigned_cmp()(args1.size(), args2.size());
|
||||
auto lemma = mk_ext_congr_lemma(k1.m_R, k1.m_expr);
|
||||
lean_assert(lemma);
|
||||
if (!lemma->m_fixed_fun) {
|
||||
int r = g_cc->compare_root(get_eq_name(), fn1, fn2);
|
||||
if (k1.m_eq || k1.m_iff)
|
||||
return g_cc->cmp_eq_iff_keys(k1, k2);
|
||||
if (k1.m_symm_rel)
|
||||
return g_cc->cmp_symm_rel_keys(k1, k2);
|
||||
|
||||
lean_assert(!k1.m_eq && !k2.m_eq && !k1.m_iff && !k2.m_iff &&
|
||||
!k1.m_symm_rel && !k2.m_symm_rel);
|
||||
lean_assert(k1.m_R == k2.m_R);
|
||||
buffer<expr> args1, args2;
|
||||
expr const & fn1 = get_app_args(k1.m_expr, args1);
|
||||
expr const & fn2 = get_app_args(k2.m_expr, args2);
|
||||
if (args1.size() != args2.size())
|
||||
return unsigned_cmp()(args1.size(), args2.size());
|
||||
auto lemma = mk_ext_congr_lemma(k1.m_R, k1.m_expr);
|
||||
lean_assert(lemma);
|
||||
if (!lemma->m_fixed_fun) {
|
||||
int r = g_cc->compare_root(get_eq_name(), fn1, fn2);
|
||||
if (r != 0) return r;
|
||||
for (unsigned i = 0; i < args1.size(); i++) {
|
||||
r = g_cc->compare_root(get_eq_name(), args1[i], args2[i]);
|
||||
if (r != 0) return r;
|
||||
for (unsigned i = 0; i < args1.size(); i++) {
|
||||
r = g_cc->compare_root(get_eq_name(), args1[i], args2[i]);
|
||||
if (r != 0) return r;
|
||||
}
|
||||
return 0;
|
||||
} else {
|
||||
list<optional<name>> const * it1 = &lemma->m_rel_names;
|
||||
list<congr_arg_kind> const * it2 = &lemma->m_congr_lemma.get_arg_kinds();
|
||||
int r;
|
||||
for (unsigned i = 0; i < args1.size(); i++) {
|
||||
lean_assert(*it1); lean_assert(*it2);
|
||||
switch (head(*it2)) {
|
||||
case congr_arg_kind::HEq:
|
||||
case congr_arg_kind::Eq:
|
||||
lean_assert(head(*it1));
|
||||
r = g_cc->compare_root(*head(*it1), args1[i], args2[i]);
|
||||
if (r != 0) return r;
|
||||
break;
|
||||
case congr_arg_kind::Fixed:
|
||||
case congr_arg_kind::FixedNoParam:
|
||||
r = expr_quick_cmp()(args1[i], args2[i]);
|
||||
if (r != 0) return r;
|
||||
break;
|
||||
case congr_arg_kind::Cast:
|
||||
// do nothing... ignore argument
|
||||
break;
|
||||
}
|
||||
it1 = &(tail(*it1));
|
||||
it2 = &(tail(*it2));
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
return 0;
|
||||
} else {
|
||||
list<optional<name>> const * it1 = &lemma->m_rel_names;
|
||||
list<congr_arg_kind> const * it2 = &lemma->m_congr_lemma.get_arg_kinds();
|
||||
int r;
|
||||
for (unsigned i = 0; i < args1.size(); i++) {
|
||||
lean_assert(*it1); lean_assert(*it2);
|
||||
switch (head(*it2)) {
|
||||
case congr_arg_kind::HEq:
|
||||
case congr_arg_kind::Eq:
|
||||
lean_assert(head(*it1));
|
||||
r = g_cc->compare_root(*head(*it1), args1[i], args2[i]);
|
||||
if (r != 0) return r;
|
||||
break;
|
||||
case congr_arg_kind::Fixed:
|
||||
case congr_arg_kind::FixedNoParam:
|
||||
r = expr_quick_cmp()(args1[i], args2[i]);
|
||||
if (r != 0) return r;
|
||||
break;
|
||||
case congr_arg_kind::Cast:
|
||||
// do nothing... ignore argument
|
||||
break;
|
||||
}
|
||||
it1 = &(tail(*it1));
|
||||
it2 = &(tail(*it2));
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -773,17 +825,41 @@ void congruence_closure::check_iff_true(congr_key const & k) {
|
|||
push_todo(get_iff_name(), e, mk_true(), *g_iff_true_mark, heq_proof);
|
||||
}
|
||||
|
||||
void congruence_closure::add_eq_congruence_table(expr const & e) {
|
||||
lean_assert(is_app(e));
|
||||
lean_assert(g_heq_based);
|
||||
eq_congr_key k = mk_eq_congr_key(e);
|
||||
if (auto old_k = m_eq_congruences.find(k)) {
|
||||
/*
|
||||
Found new equivalence: e ~ old_k->m_expr
|
||||
1. Update m_cg_root field for e
|
||||
*/
|
||||
eqc_key k(get_eq_name(), e);
|
||||
entry new_entry = *m_entries.find(k);
|
||||
new_entry.m_cg_root = old_k->m_expr;
|
||||
m_entries.insert(k, new_entry);
|
||||
/* 2. Put new equivalence in the TODO queue */
|
||||
/* TODO(Leo): check if the following line is a bottleneck */
|
||||
bool heq_proof = !is_def_eq(infer_type(e), infer_type(old_k->m_expr));
|
||||
push_todo(get_eq_name(), e, old_k->m_expr, *g_congr_mark, heq_proof);
|
||||
} else {
|
||||
m_eq_congruences.insert(k);
|
||||
}
|
||||
}
|
||||
|
||||
void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, expr const & e) {
|
||||
lean_assert(is_app(e));
|
||||
congr_key k = mk_congr_key(lemma, e);
|
||||
if (auto old_k = m_congruences.find(k)) {
|
||||
// Found new equivalence: e ~ old_k->m_expr
|
||||
// 1. Update m_cg_root field for e
|
||||
/*
|
||||
Found new equivalence: e ~ old_k->m_expr
|
||||
1. Update m_cg_root field for e
|
||||
*/
|
||||
eqc_key k(lemma.m_R, e);
|
||||
entry new_entry = *m_entries.find(k);
|
||||
new_entry.m_cg_root = old_k->m_expr;
|
||||
m_entries.insert(k, new_entry);
|
||||
// 2. Put new equivalence in the TODO queue
|
||||
/* 2. Put new equivalence in the TODO queue */
|
||||
bool heq_proof = false;
|
||||
if (lemma.m_heq_result) {
|
||||
lean_assert(g_heq_based);
|
||||
|
|
@ -893,32 +969,44 @@ void congruence_closure::internalize_core(name R, expr const & e, bool toplevel,
|
|||
case expr_kind::App: {
|
||||
bool is_lapp = is_logical_app(e);
|
||||
mk_entry_core(R, e, to_propagate && !is_lapp);
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(e, args);
|
||||
if (toplevel) {
|
||||
if (is_lapp) {
|
||||
to_propagate = true; // we must propagate the children of a top-level logical app (or, and, iff, ite)
|
||||
} else {
|
||||
toplevel = false; // children of a non-logical application will not be marked as toplevel
|
||||
}
|
||||
if (R == get_eq_name() && g_heq_based) {
|
||||
bool toplevel = false;
|
||||
bool to_propagate = false;
|
||||
internalize_core(R, app_fn(e), toplevel, to_propagate);
|
||||
internalize_core(R, app_arg(e), toplevel, to_propagate);
|
||||
add_occurrence(R, e, R, app_fn(e));
|
||||
add_occurrence(R, e, R, app_arg(e));
|
||||
add_eq_congruence_table(e);
|
||||
} else {
|
||||
to_propagate = false;
|
||||
}
|
||||
if (auto lemma = mk_ext_congr_lemma(R, e)) {
|
||||
list<optional<name>> const * it = &(lemma->m_rel_names);
|
||||
for (expr const & arg : args) {
|
||||
lean_assert(*it);
|
||||
if (auto R1 = head(*it)) {
|
||||
internalize_core(*R1, arg, toplevel, to_propagate);
|
||||
add_occurrence(R, e, *R1, arg);
|
||||
/* Handle user-defined congruence lemmas, congruence lemmas for other relations,
|
||||
and automatically generated lemmas for weakly-dependent-functions and relations. */
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(e, args);
|
||||
if (toplevel) {
|
||||
if (is_lapp) {
|
||||
to_propagate = true; // we must propagate the children of a top-level logical app (or, and, iff, ite)
|
||||
} else {
|
||||
toplevel = false; // children of a non-logical application will not be marked as toplevel
|
||||
}
|
||||
it = &tail(*it);
|
||||
} else {
|
||||
to_propagate = false;
|
||||
}
|
||||
if (!lemma->m_fixed_fun) {
|
||||
internalize_core(get_eq_name(), fn, false, false);
|
||||
add_occurrence(get_eq_name(), e, get_eq_name(), fn);
|
||||
if (auto lemma = mk_ext_congr_lemma(R, e)) {
|
||||
list<optional<name>> const * it = &(lemma->m_rel_names);
|
||||
for (expr const & arg : args) {
|
||||
lean_assert(*it);
|
||||
if (auto R1 = head(*it)) {
|
||||
internalize_core(*R1, arg, toplevel, to_propagate);
|
||||
add_occurrence(R, e, *R1, arg);
|
||||
}
|
||||
it = &tail(*it);
|
||||
}
|
||||
if (!lemma->m_fixed_fun) {
|
||||
internalize_core(get_eq_name(), fn, false, false);
|
||||
add_occurrence(get_eq_name(), e, get_eq_name(), fn);
|
||||
}
|
||||
add_congruence_table(*lemma, e);
|
||||
}
|
||||
add_congruence_table(*lemma, e);
|
||||
}
|
||||
apply_simple_eqvs(e);
|
||||
break;
|
||||
|
|
@ -972,10 +1060,15 @@ void congruence_closure::remove_parents(name const & R, expr const & e) {
|
|||
auto ps = m_parents.find(child_key(R, e));
|
||||
if (!ps) return;
|
||||
ps->for_each([&](parent_occ const & p) {
|
||||
auto lemma = mk_ext_congr_lemma(p.m_R, p.m_expr);
|
||||
lean_assert(lemma);
|
||||
congr_key k = mk_congr_key(*lemma, p.m_expr);
|
||||
m_congruences.erase(k);
|
||||
if (g_heq_based && R == get_eq_name() && p.m_R == get_eq_name()) {
|
||||
eq_congr_key k = mk_eq_congr_key(p.m_expr);
|
||||
m_eq_congruences.erase(k);
|
||||
} else {
|
||||
auto lemma = mk_ext_congr_lemma(p.m_R, p.m_expr);
|
||||
lean_assert(lemma);
|
||||
congr_key k = mk_congr_key(*lemma, p.m_expr);
|
||||
m_congruences.erase(k);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -983,9 +1076,13 @@ void congruence_closure::reinsert_parents(name const & R, expr const & e) {
|
|||
auto ps = m_parents.find(child_key(R, e));
|
||||
if (!ps) return;
|
||||
ps->for_each([&](parent_occ const & p) {
|
||||
auto lemma = mk_ext_congr_lemma(p.m_R, p.m_expr);
|
||||
lean_assert(lemma);
|
||||
add_congruence_table(*lemma, p.m_expr);
|
||||
if (g_heq_based && R == get_eq_name() && p.m_R == get_eq_name()) {
|
||||
add_eq_congruence_table(p.m_expr);
|
||||
} else {
|
||||
auto lemma = mk_ext_congr_lemma(p.m_R, p.m_expr);
|
||||
lean_assert(lemma);
|
||||
add_congruence_table(*lemma, p.m_expr);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -1037,18 +1134,20 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
|
|||
lean_assert(r1 && r2);
|
||||
bool flipped = false;
|
||||
|
||||
// We want r2 to be the root of the combined class.
|
||||
/* We want r2 to be the root of the combined class. */
|
||||
|
||||
// We swap (e1,n1,r1) with (e2,n2,r2) when
|
||||
// 1- r1->m_interpreted && !r2->m_interpreted.
|
||||
// Reason: to decide when to propagate we check whether the root of the equivalence class
|
||||
// is true/false. So, this condition is to make sure if true/false is an equivalence class,
|
||||
// then one of them is the root. If both are, it doesn't matter, since the state is inconsistent
|
||||
// anyway.
|
||||
// 2- r1->m_constructor && !r2->m_interpreted && !r2->m_constructor
|
||||
// Reason: we want constructors to be the representative of their equivalence classes.
|
||||
// 2- r1->m_size > r2->m_size && !r2->m_interpreted && !r2->m_constructor
|
||||
// Reason: performance.
|
||||
/*
|
||||
We swap (e1,n1,r1) with (e2,n2,r2) when
|
||||
1- r1->m_interpreted && !r2->m_interpreted.
|
||||
Reason: to decide when to propagate we check whether the root of the equivalence class
|
||||
is true/false. So, this condition is to make sure if true/false is an equivalence class,
|
||||
then one of them is the root. If both are, it doesn't matter, since the state is inconsistent
|
||||
anyway.
|
||||
2- r1->m_constructor && !r2->m_interpreted && !r2->m_constructor
|
||||
Reason: we want constructors to be the representative of their equivalence classes.
|
||||
3- r1->m_size > r2->m_size && !r2->m_interpreted && !r2->m_constructor
|
||||
Reason: performance.
|
||||
*/
|
||||
if ((r1->m_interpreted && !r2->m_interpreted) ||
|
||||
(r1->m_constructor && !r2->m_interpreted && !r2->m_constructor) ||
|
||||
(r1->m_size > r2->m_size && !r2->m_interpreted && !r2->m_constructor)) {
|
||||
|
|
@ -1068,21 +1167,23 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
|
|||
expr e2_root = n2->m_root;
|
||||
entry new_n1 = *n1;
|
||||
|
||||
// Following target/proof we have
|
||||
// e1 -> ... -> r1
|
||||
// e2 -> ... -> r2
|
||||
// We want
|
||||
// r1 -> ... -> e1 -> e2 -> ... -> r2
|
||||
/*
|
||||
Following target/proof we have
|
||||
e1 -> ... -> r1
|
||||
e2 -> ... -> r2
|
||||
We want
|
||||
r1 -> ... -> e1 -> e2 -> ... -> r2
|
||||
*/
|
||||
invert_trans(R, e1);
|
||||
new_n1.m_target = e2;
|
||||
new_n1.m_proof = H;
|
||||
new_n1.m_flipped = flipped;
|
||||
m_entries.insert(eqc_key(R, e1), new_n1);
|
||||
|
||||
// The hash code for the parents is going to change
|
||||
/* The hash code for the parents is going to change */
|
||||
remove_parents(R, e1_root);
|
||||
|
||||
// force all m_root fields in e1 equivalence class to point to e2_root
|
||||
/* force all m_root fields in e1 equivalence class to point to e2_root */
|
||||
bool propagate = R == get_iff_name() && is_true_or_false(e2_root);
|
||||
buffer<expr> to_propagate;
|
||||
expr it = e1;
|
||||
|
|
@ -1296,7 +1397,74 @@ bool congruence_closure::is_eqv(name const & R, expr const & e1, expr const & e2
|
|||
return n1->m_root == n2->m_root;
|
||||
}
|
||||
|
||||
expr congruence_closure::mk_eq_congr_proof(expr const & lhs, expr const & rhs, bool heq_proofs) const {
|
||||
lean_assert(g_heq_based);
|
||||
app_builder & b = get_app_builder();
|
||||
buffer<expr> lhs_args, rhs_args;
|
||||
expr const * lhs_it = &lhs;
|
||||
expr const * rhs_it = &rhs;
|
||||
while (is_app(*lhs_it) && is_app(*rhs_it) && *lhs_it != *rhs_it) {
|
||||
lean_assert(is_eqv(get_eq_name(), *lhs_it, *rhs_it));
|
||||
lhs_args.push_back(app_arg(*lhs_it));
|
||||
rhs_args.push_back(app_arg(*rhs_it));
|
||||
lhs_it = &app_fn(*lhs_it);
|
||||
rhs_it = &app_fn(*rhs_it);
|
||||
}
|
||||
if (lhs_args.empty()) {
|
||||
if (heq_proofs)
|
||||
return b.mk_heq_refl(lhs);
|
||||
else
|
||||
return b.mk_eq_refl(lhs);
|
||||
}
|
||||
std::reverse(lhs_args.begin(), lhs_args.end());
|
||||
std::reverse(rhs_args.begin(), rhs_args.end());
|
||||
lean_assert(lhs_args.size() == rhs_args.size());
|
||||
expr const & lhs_fn = *lhs_it;
|
||||
expr const & rhs_fn = *rhs_it;
|
||||
lean_assert(is_eqv(get_eq_name(), lhs_fn, rhs_fn) || is_def_eq(lhs_fn, rhs_fn));
|
||||
lean_assert(is_def_eq(infer_type(lhs_fn), infer_type(rhs_fn)));
|
||||
/* Create proof for
|
||||
(lhs_fn lhs_args[0] ... lhs_args[n-1]) = (lhs_fn rhs_args[0] ... rhs_args[n-1])
|
||||
where
|
||||
n == lhs_args.size()
|
||||
*/
|
||||
auto spec_lemma = mk_ext_hcongr_lemma(lhs_fn, lhs_args.size());
|
||||
lean_assert(spec_lemma);
|
||||
list<congr_arg_kind> const * kinds_it = &spec_lemma->m_congr_lemma.get_arg_kinds();
|
||||
buffer<expr> lemma_args;
|
||||
for (unsigned i = 0; i < lhs_args.size(); i++) {
|
||||
lean_assert(kinds_it);
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
lemma_args.push_back(rhs_args[i]);
|
||||
if (head(*kinds_it) == congr_arg_kind::HEq) {
|
||||
lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i]));
|
||||
} else {
|
||||
lean_assert(head(*kinds_it) == congr_arg_kind::Eq);
|
||||
lemma_args.push_back(*get_eqv_proof(get_eq_name(), lhs_args[i], rhs_args[i]));
|
||||
}
|
||||
kinds_it = &(tail(*kinds_it));
|
||||
}
|
||||
expr r = mk_app(spec_lemma->m_congr_lemma.get_proof(), lemma_args);
|
||||
if (spec_lemma->m_heq_result && !heq_proofs)
|
||||
r = b.mk_eq_of_heq(r);
|
||||
else if (!spec_lemma->m_heq_result && heq_proofs)
|
||||
r = b.mk_heq_of_eq(r);
|
||||
if (is_def_eq(lhs_fn, rhs_fn))
|
||||
return r;
|
||||
/* Convert r into a proof of lhs = rhs using eq.rec and
|
||||
the proof that lhs_fn = rhs_fn */
|
||||
expr lhs_fn_eq_rhs_fn = *get_eqv_proof(get_eq_name(), lhs_fn, rhs_fn);
|
||||
expr x = mk_fresh_local(infer_type(lhs_fn));
|
||||
expr motive_rhs = mk_app(x, rhs_args);
|
||||
expr motive = heq_proofs ? b.mk_heq(lhs, motive_rhs) : b.mk_eq(lhs, motive_rhs);
|
||||
return b.mk_eq_rec(Fun(x, motive), r, lhs_fn_eq_rhs_fn);
|
||||
}
|
||||
|
||||
expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, expr const & rhs, bool heq_proofs) const {
|
||||
if (g_heq_based && (R == get_eq_name() || R == get_heq_name())) {
|
||||
/* Use general eq congruence lemmas when heterogeneous equality is enabled. */
|
||||
return mk_eq_congr_proof(lhs, rhs, heq_proofs);
|
||||
}
|
||||
app_builder & b = get_app_builder();
|
||||
buffer<expr> lhs_args, rhs_args;
|
||||
expr const & lhs_fn = get_app_args(lhs, lhs_args);
|
||||
|
|
@ -1305,92 +1473,49 @@ expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, e
|
|||
auto lemma = mk_ext_congr_lemma(R, lhs);
|
||||
lean_assert(lemma);
|
||||
if (lemma->m_fixed_fun) {
|
||||
if (g_heq_based && lemma->m_hcongr_lemma && (R == get_eq_name() || R == get_heq_name())) {
|
||||
/* Try to simplify congruence proof by consuming common prefix of lhs and rhs */
|
||||
/* This branch is an optimization, and it is not necessary */
|
||||
unsigned i = 0;
|
||||
for (; i < lhs_args.size(); i++) {
|
||||
if (!is_def_eq(lhs_args[i], rhs_args[i]))
|
||||
break;
|
||||
}
|
||||
unsigned prefix_sz = i;
|
||||
unsigned rest_sz = lhs_args.size() - prefix_sz;
|
||||
if (rest_sz == 0) {
|
||||
if (heq_proofs)
|
||||
return b.mk_heq_refl(lhs);
|
||||
else
|
||||
return b.mk_eq_refl(lhs);
|
||||
}
|
||||
expr g = lhs;
|
||||
for (unsigned i = 0; i < rest_sz; i++) g = app_fn(g);
|
||||
auto spec_lemma = mk_ext_hcongr_lemma(g, rest_sz);
|
||||
lean_assert(spec_lemma);
|
||||
list<congr_arg_kind> const * it = &spec_lemma->m_congr_lemma.get_arg_kinds();
|
||||
buffer<expr> lemma_args;
|
||||
for (unsigned i = prefix_sz; i < lhs_args.size(); i++) {
|
||||
lean_assert(it);
|
||||
/* Main case: convers user-defined congruence lemmas, and
|
||||
all automatically generated congruence lemmas */
|
||||
list<optional<name>> const * it1 = &lemma->m_rel_names;
|
||||
list<congr_arg_kind> const * it2 = &lemma->m_congr_lemma.get_arg_kinds();
|
||||
buffer<expr> lemma_args;
|
||||
for (unsigned i = 0; i < lhs_args.size(); i++) {
|
||||
lean_assert(*it1 && *it2);
|
||||
switch (head(*it2)) {
|
||||
case congr_arg_kind::HEq:
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
lemma_args.push_back(rhs_args[i]);
|
||||
if (head(*it) == congr_arg_kind::HEq) {
|
||||
lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i]));
|
||||
} else {
|
||||
lean_assert(head(*it) == congr_arg_kind::Eq);
|
||||
lemma_args.push_back(*get_eqv_proof(get_eq_name(), lhs_args[i], rhs_args[i]));
|
||||
}
|
||||
it = &(tail(*it));
|
||||
lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i]));
|
||||
break;
|
||||
case congr_arg_kind::Eq:
|
||||
lean_assert(head(*it1));
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
lemma_args.push_back(rhs_args[i]);
|
||||
lemma_args.push_back(*get_eqv_proof(*head(*it1), lhs_args[i], rhs_args[i]));
|
||||
break;
|
||||
case congr_arg_kind::Fixed:
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
break;
|
||||
case congr_arg_kind::FixedNoParam:
|
||||
break;
|
||||
case congr_arg_kind::Cast:
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
lemma_args.push_back(rhs_args[i]);
|
||||
break;
|
||||
}
|
||||
expr r = mk_app(spec_lemma->m_congr_lemma.get_proof(), lemma_args);
|
||||
lean_assert(g_heq_based);
|
||||
if (spec_lemma->m_heq_result && !heq_proofs)
|
||||
r = b.mk_eq_of_heq(r);
|
||||
else if (!spec_lemma->m_heq_result && heq_proofs)
|
||||
r = b.mk_heq_of_eq(r);
|
||||
return r;
|
||||
} else {
|
||||
/* Main case: convers user-defined congruence lemmas, and
|
||||
all automatically generated congruence lemmas */
|
||||
list<optional<name>> const * it1 = &lemma->m_rel_names;
|
||||
list<congr_arg_kind> const * it2 = &lemma->m_congr_lemma.get_arg_kinds();
|
||||
buffer<expr> lemma_args;
|
||||
for (unsigned i = 0; i < lhs_args.size(); i++) {
|
||||
lean_assert(*it1 && *it2);
|
||||
switch (head(*it2)) {
|
||||
case congr_arg_kind::HEq:
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
lemma_args.push_back(rhs_args[i]);
|
||||
lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i]));
|
||||
break;
|
||||
case congr_arg_kind::Eq:
|
||||
lean_assert(head(*it1));
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
lemma_args.push_back(rhs_args[i]);
|
||||
lemma_args.push_back(*get_eqv_proof(*head(*it1), lhs_args[i], rhs_args[i]));
|
||||
break;
|
||||
case congr_arg_kind::Fixed:
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
break;
|
||||
case congr_arg_kind::FixedNoParam:
|
||||
break;
|
||||
case congr_arg_kind::Cast:
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
lemma_args.push_back(rhs_args[i]);
|
||||
break;
|
||||
}
|
||||
it1 = &(tail(*it1));
|
||||
it2 = &(tail(*it2));
|
||||
}
|
||||
expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args);
|
||||
if (lemma->m_lift_needed) {
|
||||
r = b.lift_from_eq(R, r);
|
||||
}
|
||||
if (g_heq_based) {
|
||||
if (lemma->m_heq_result && !heq_proofs)
|
||||
r = b.mk_eq_of_heq(r);
|
||||
else if (!lemma->m_heq_result && heq_proofs)
|
||||
r = b.mk_heq_of_eq(r);
|
||||
}
|
||||
return r;
|
||||
it1 = &(tail(*it1));
|
||||
it2 = &(tail(*it2));
|
||||
}
|
||||
expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args);
|
||||
if (lemma->m_lift_needed) {
|
||||
r = b.lift_from_eq(R, r);
|
||||
}
|
||||
if (g_heq_based) {
|
||||
if (lemma->m_heq_result && !heq_proofs)
|
||||
r = b.mk_eq_of_heq(r);
|
||||
else if (!lemma->m_heq_result && heq_proofs)
|
||||
r = b.mk_heq_of_eq(r);
|
||||
}
|
||||
return r;
|
||||
} else {
|
||||
/* This branch builds congruence proofs that handle equality between functions.
|
||||
The proof is created using congr_arg/congr_fun/congr lemmas.
|
||||
|
|
|
|||
|
|
@ -76,7 +76,22 @@ class congruence_closure {
|
|||
}
|
||||
};
|
||||
|
||||
/* Key for the congruence set */
|
||||
/* Key for the equality congruence table.
|
||||
|
||||
\remark We only use the equality congruence table when the support for heterogeneous equality
|
||||
is turned on (see blast.cc.heq option). Otherwise, we store equality congruences in the
|
||||
generic congruence table and rely on automatically generated congruence lemmas for
|
||||
weakly dependent functions. */
|
||||
struct eq_congr_key {
|
||||
expr m_expr;
|
||||
unsigned m_hash;
|
||||
};
|
||||
|
||||
struct eq_congr_key_cmp {
|
||||
int operator()(eq_congr_key const & k1, eq_congr_key const & k2) const;
|
||||
};
|
||||
|
||||
/* Key for the congruence table. */
|
||||
struct congr_key {
|
||||
name m_R;
|
||||
expr m_expr;
|
||||
|
|
@ -93,6 +108,9 @@ class congruence_closure {
|
|||
congr_key() { m_eq = 0; m_iff = 0; m_symm_rel = 0; }
|
||||
};
|
||||
|
||||
int cmp_eq_iff_keys(congr_key const & k1, congr_key const & k2) const;
|
||||
int cmp_symm_rel_keys(congr_key const & k1, congr_key const & k2) const;
|
||||
|
||||
struct congr_key_cmp {
|
||||
int operator()(congr_key const & k1, congr_key const & k2) const;
|
||||
};
|
||||
|
|
@ -105,10 +123,12 @@ class congruence_closure {
|
|||
typedef eqc_key_cmp parent_occ_cmp;
|
||||
typedef rb_tree<parent_occ, parent_occ_cmp> parent_occ_set;
|
||||
typedef rb_map<child_key, parent_occ_set, child_key_cmp> parents;
|
||||
typedef rb_tree<eq_congr_key, eq_congr_key_cmp> eq_congruences;
|
||||
typedef rb_tree<congr_key, congr_key_cmp> congruences;
|
||||
typedef rb_map<expr, expr, expr_quick_cmp> subsingleton_reprs;
|
||||
entries m_entries;
|
||||
parents m_parents;
|
||||
eq_congruences m_eq_congruences;
|
||||
congruences m_congruences;
|
||||
/** The following mapping store a representative for each subsingleton type */
|
||||
subsingleton_reprs m_subsingleton_reprs;
|
||||
|
|
@ -129,6 +149,7 @@ class congruence_closure {
|
|||
|
||||
int compare_symm(name const & R, expr lhs1, expr rhs1, expr lhs2, expr rhs2) const;
|
||||
int compare_root(name const & R, expr e1, expr e2) const;
|
||||
eq_congr_key mk_eq_congr_key(expr const & e) const;
|
||||
unsigned symm_hash(name const & R, expr const & lhs, expr const & rhs) const;
|
||||
congr_key mk_congr_key(ext_congr_lemma const & lemma, expr const & e) const;
|
||||
void check_iff_true(congr_key const & k);
|
||||
|
|
@ -140,6 +161,7 @@ class congruence_closure {
|
|||
void mk_entry_core(name const & R, expr const & e, bool to_propagate);
|
||||
void mk_entry(name const & R, expr const & e, bool to_propagate);
|
||||
void add_occurrence(name const & Rp, expr const & parent, name const & Rc, expr const & child);
|
||||
void add_eq_congruence_table(expr const & e);
|
||||
void add_congruence_table(ext_congr_lemma const & lemma, expr const & e);
|
||||
void invert_trans(name const & R, expr const & e, bool new_flipped, optional<expr> new_target, optional<expr> new_proof);
|
||||
void invert_trans(name const & R, expr const & e);
|
||||
|
|
@ -152,6 +174,7 @@ class congruence_closure {
|
|||
void add_eqv_core(name const & R, expr const & lhs, expr const & rhs, expr const & H, optional<expr> const & added_prop, bool heq_proof);
|
||||
void propagate_no_confusion_eq(expr const & e1, expr const & e2);
|
||||
|
||||
expr mk_eq_congr_proof(expr const & lhs, expr const & rhs, bool heq_proofs) const;
|
||||
expr mk_congr_proof_core(name const & R, expr const & lhs, expr const & rhs, bool heq_proofs) const;
|
||||
expr mk_congr_proof(name const & R, expr const & lhs, expr const & rhs, bool heq_proofs) const;
|
||||
expr mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H, bool heq_proofs) const;
|
||||
|
|
|
|||
8
tests/lean/run/blast_cc_heq9.lean
Normal file
8
tests/lean/run/blast_cc_heq9.lean
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
example (f g : Π {A : Type₁}, A → A → A) (h : nat → nat) (a b : nat) :
|
||||
h = f a → h b = f a b :=
|
||||
by blast
|
||||
|
||||
|
||||
example (f g : Π {A : Type₁} (a b : A), A) (h : nat → nat) (a b : nat) :
|
||||
h = f a → h b = f a b :=
|
||||
by blast
|
||||
Loading…
Add table
Reference in a new issue