feat(library/blast/congruence_closure): more general congruence lemmas

This commit is contained in:
Leonardo de Moura 2016-01-13 21:44:32 -08:00
parent d9294fc164
commit 3f7122ce07
3 changed files with 348 additions and 192 deletions

View file

@ -613,6 +613,70 @@ int congruence_closure::compare_root(name const & R, expr e1, expr e2) const {
return expr_quick_cmp()(e1, e2);
}
int congruence_closure::eq_congr_key_cmp::operator()(eq_congr_key const & k1, eq_congr_key const & k2) const {
lean_assert(g_heq_based);
if (k1.m_hash != k2.m_hash)
return unsigned_cmp()(k1.m_hash, k2.m_hash);
lean_assert(is_app(k1.m_expr) && is_app(k2.m_expr));
expr const & f1 = app_fn(k1.m_expr);
expr const & a1 = app_arg(k1.m_expr);
expr const & f2 = app_fn(k2.m_expr);
expr const & a2 = app_arg(k2.m_expr);
int r = g_cc->compare_root(get_eq_name(), a1, a2);
if (r != 0) return r;
r = g_cc->compare_root(get_eq_name(), f1, f2);
if (r != 0) return r;
if (is_app(f1) && is_app(f2)) return 0; /* composite */
if (f1 == f2) return 0; /* identical functions */
expr f1_type = infer_type(f1);
expr f2_type = infer_type(f2);
if (is_def_eq(f1_type, f2_type)) return 0; /* same function type */
/*
f1 and f2 are not applications and have different types.
We can't generate a congruence proof in this case because the following lemma
hcongr : f1 == f2 -> a1 == a2 -> f1 a1 == f2 a2
is not provable. Remark: it is also not provable in MLTT, Coq and Agda (even if we assume UIP).
*/
return expr_quick_cmp()(f1, f2);
}
/* \brief Create a equality congruence table key.
\remark This table and key are only used when heterogeneous equality support is enabled. */
auto congruence_closure::mk_eq_congr_key(expr const & e) const -> eq_congr_key {
lean_assert(is_app(e));
eq_congr_key k;
k.m_expr = e;
expr const & f = app_fn(e);
expr const & a = app_arg(e);
unsigned h = hash(get_root(get_eq_name(), f).hash(), get_root(get_eq_name(), a).hash());
k.m_hash = h;
return k;
}
int congruence_closure::cmp_eq_iff_keys(congr_key const & k1, congr_key const & k2) const {
lean_assert(k1.m_eq == k2.m_eq);
lean_assert(k1.m_iff == k2.m_iff);
lean_assert(k1.m_eq != k1.m_iff);
name const & R = k1.m_eq ? get_eq_name() : get_iff_name();
expr const & lhs1 = app_arg(app_fn(k1.m_expr));
expr const & rhs1 = app_arg(k1.m_expr);
expr const & lhs2 = app_arg(app_fn(k2.m_expr));
expr const & rhs2 = app_arg(k2.m_expr);
return compare_symm(R, lhs1, rhs1, lhs2, rhs2);
}
int congruence_closure::cmp_symm_rel_keys(congr_key const & k1, congr_key const & k2) const {
name R1, R2;
expr lhs1, rhs1, lhs2, rhs2;
lean_verify(is_equivalence_relation_app(k1.m_expr, R1, lhs1, rhs1));
lean_verify(is_equivalence_relation_app(k2.m_expr, R2, lhs2, rhs2));
if (R1 != R2)
return quick_cmp(R1, R2);
return compare_symm(R1, lhs1, rhs1, lhs2, rhs2);
}
int congruence_closure::congr_key_cmp::operator()(congr_key const & k1, congr_key const & k2) const {
if (k1.m_hash != k2.m_hash)
return unsigned_cmp()(k1.m_hash, k2.m_hash);
@ -624,67 +688,55 @@ int congruence_closure::congr_key_cmp::operator()(congr_key const & k1, congr_ke
return k1.m_iff ? -1 : 1;
if (k2.m_symm_rel != k2.m_symm_rel)
return k1.m_symm_rel ? -1 : 1;
if (k1.m_eq || k1.m_iff) {
name const & R = k1.m_eq ? get_eq_name() : get_iff_name();
expr const & lhs1 = app_arg(app_fn(k1.m_expr));
expr const & rhs1 = app_arg(k1.m_expr);
expr const & lhs2 = app_arg(app_fn(k2.m_expr));
expr const & rhs2 = app_arg(k2.m_expr);
return g_cc->compare_symm(R, lhs1, rhs1, lhs2, rhs2);
} else if (k1.m_symm_rel) {
name R1, R2;
expr lhs1, rhs1, lhs2, rhs2;
lean_verify(is_equivalence_relation_app(k1.m_expr, R1, lhs1, rhs1));
lean_verify(is_equivalence_relation_app(k2.m_expr, R2, lhs2, rhs2));
if (R1 != R2)
return quick_cmp(R1, R2);
return g_cc->compare_symm(R1, lhs1, rhs1, lhs2, rhs2);
} else {
lean_assert(!k1.m_eq && !k2.m_eq && !k1.m_iff && !k2.m_iff &&
!k1.m_symm_rel && !k2.m_symm_rel);
lean_assert(k1.m_R == k2.m_R);
buffer<expr> args1, args2;
expr const & fn1 = get_app_args(k1.m_expr, args1);
expr const & fn2 = get_app_args(k2.m_expr, args2);
if (args1.size() != args2.size())
return unsigned_cmp()(args1.size(), args2.size());
auto lemma = mk_ext_congr_lemma(k1.m_R, k1.m_expr);
lean_assert(lemma);
if (!lemma->m_fixed_fun) {
int r = g_cc->compare_root(get_eq_name(), fn1, fn2);
if (k1.m_eq || k1.m_iff)
return g_cc->cmp_eq_iff_keys(k1, k2);
if (k1.m_symm_rel)
return g_cc->cmp_symm_rel_keys(k1, k2);
lean_assert(!k1.m_eq && !k2.m_eq && !k1.m_iff && !k2.m_iff &&
!k1.m_symm_rel && !k2.m_symm_rel);
lean_assert(k1.m_R == k2.m_R);
buffer<expr> args1, args2;
expr const & fn1 = get_app_args(k1.m_expr, args1);
expr const & fn2 = get_app_args(k2.m_expr, args2);
if (args1.size() != args2.size())
return unsigned_cmp()(args1.size(), args2.size());
auto lemma = mk_ext_congr_lemma(k1.m_R, k1.m_expr);
lean_assert(lemma);
if (!lemma->m_fixed_fun) {
int r = g_cc->compare_root(get_eq_name(), fn1, fn2);
if (r != 0) return r;
for (unsigned i = 0; i < args1.size(); i++) {
r = g_cc->compare_root(get_eq_name(), args1[i], args2[i]);
if (r != 0) return r;
for (unsigned i = 0; i < args1.size(); i++) {
r = g_cc->compare_root(get_eq_name(), args1[i], args2[i]);
if (r != 0) return r;
}
return 0;
} else {
list<optional<name>> const * it1 = &lemma->m_rel_names;
list<congr_arg_kind> const * it2 = &lemma->m_congr_lemma.get_arg_kinds();
int r;
for (unsigned i = 0; i < args1.size(); i++) {
lean_assert(*it1); lean_assert(*it2);
switch (head(*it2)) {
case congr_arg_kind::HEq:
case congr_arg_kind::Eq:
lean_assert(head(*it1));
r = g_cc->compare_root(*head(*it1), args1[i], args2[i]);
if (r != 0) return r;
break;
case congr_arg_kind::Fixed:
case congr_arg_kind::FixedNoParam:
r = expr_quick_cmp()(args1[i], args2[i]);
if (r != 0) return r;
break;
case congr_arg_kind::Cast:
// do nothing... ignore argument
break;
}
it1 = &(tail(*it1));
it2 = &(tail(*it2));
}
return 0;
}
return 0;
} else {
list<optional<name>> const * it1 = &lemma->m_rel_names;
list<congr_arg_kind> const * it2 = &lemma->m_congr_lemma.get_arg_kinds();
int r;
for (unsigned i = 0; i < args1.size(); i++) {
lean_assert(*it1); lean_assert(*it2);
switch (head(*it2)) {
case congr_arg_kind::HEq:
case congr_arg_kind::Eq:
lean_assert(head(*it1));
r = g_cc->compare_root(*head(*it1), args1[i], args2[i]);
if (r != 0) return r;
break;
case congr_arg_kind::Fixed:
case congr_arg_kind::FixedNoParam:
r = expr_quick_cmp()(args1[i], args2[i]);
if (r != 0) return r;
break;
case congr_arg_kind::Cast:
// do nothing... ignore argument
break;
}
it1 = &(tail(*it1));
it2 = &(tail(*it2));
}
return 0;
}
}
@ -773,17 +825,41 @@ void congruence_closure::check_iff_true(congr_key const & k) {
push_todo(get_iff_name(), e, mk_true(), *g_iff_true_mark, heq_proof);
}
void congruence_closure::add_eq_congruence_table(expr const & e) {
lean_assert(is_app(e));
lean_assert(g_heq_based);
eq_congr_key k = mk_eq_congr_key(e);
if (auto old_k = m_eq_congruences.find(k)) {
/*
Found new equivalence: e ~ old_k->m_expr
1. Update m_cg_root field for e
*/
eqc_key k(get_eq_name(), e);
entry new_entry = *m_entries.find(k);
new_entry.m_cg_root = old_k->m_expr;
m_entries.insert(k, new_entry);
/* 2. Put new equivalence in the TODO queue */
/* TODO(Leo): check if the following line is a bottleneck */
bool heq_proof = !is_def_eq(infer_type(e), infer_type(old_k->m_expr));
push_todo(get_eq_name(), e, old_k->m_expr, *g_congr_mark, heq_proof);
} else {
m_eq_congruences.insert(k);
}
}
void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, expr const & e) {
lean_assert(is_app(e));
congr_key k = mk_congr_key(lemma, e);
if (auto old_k = m_congruences.find(k)) {
// Found new equivalence: e ~ old_k->m_expr
// 1. Update m_cg_root field for e
/*
Found new equivalence: e ~ old_k->m_expr
1. Update m_cg_root field for e
*/
eqc_key k(lemma.m_R, e);
entry new_entry = *m_entries.find(k);
new_entry.m_cg_root = old_k->m_expr;
m_entries.insert(k, new_entry);
// 2. Put new equivalence in the TODO queue
/* 2. Put new equivalence in the TODO queue */
bool heq_proof = false;
if (lemma.m_heq_result) {
lean_assert(g_heq_based);
@ -893,32 +969,44 @@ void congruence_closure::internalize_core(name R, expr const & e, bool toplevel,
case expr_kind::App: {
bool is_lapp = is_logical_app(e);
mk_entry_core(R, e, to_propagate && !is_lapp);
buffer<expr> args;
expr const & fn = get_app_args(e, args);
if (toplevel) {
if (is_lapp) {
to_propagate = true; // we must propagate the children of a top-level logical app (or, and, iff, ite)
} else {
toplevel = false; // children of a non-logical application will not be marked as toplevel
}
if (R == get_eq_name() && g_heq_based) {
bool toplevel = false;
bool to_propagate = false;
internalize_core(R, app_fn(e), toplevel, to_propagate);
internalize_core(R, app_arg(e), toplevel, to_propagate);
add_occurrence(R, e, R, app_fn(e));
add_occurrence(R, e, R, app_arg(e));
add_eq_congruence_table(e);
} else {
to_propagate = false;
}
if (auto lemma = mk_ext_congr_lemma(R, e)) {
list<optional<name>> const * it = &(lemma->m_rel_names);
for (expr const & arg : args) {
lean_assert(*it);
if (auto R1 = head(*it)) {
internalize_core(*R1, arg, toplevel, to_propagate);
add_occurrence(R, e, *R1, arg);
/* Handle user-defined congruence lemmas, congruence lemmas for other relations,
and automatically generated lemmas for weakly-dependent-functions and relations. */
buffer<expr> args;
expr const & fn = get_app_args(e, args);
if (toplevel) {
if (is_lapp) {
to_propagate = true; // we must propagate the children of a top-level logical app (or, and, iff, ite)
} else {
toplevel = false; // children of a non-logical application will not be marked as toplevel
}
it = &tail(*it);
} else {
to_propagate = false;
}
if (!lemma->m_fixed_fun) {
internalize_core(get_eq_name(), fn, false, false);
add_occurrence(get_eq_name(), e, get_eq_name(), fn);
if (auto lemma = mk_ext_congr_lemma(R, e)) {
list<optional<name>> const * it = &(lemma->m_rel_names);
for (expr const & arg : args) {
lean_assert(*it);
if (auto R1 = head(*it)) {
internalize_core(*R1, arg, toplevel, to_propagate);
add_occurrence(R, e, *R1, arg);
}
it = &tail(*it);
}
if (!lemma->m_fixed_fun) {
internalize_core(get_eq_name(), fn, false, false);
add_occurrence(get_eq_name(), e, get_eq_name(), fn);
}
add_congruence_table(*lemma, e);
}
add_congruence_table(*lemma, e);
}
apply_simple_eqvs(e);
break;
@ -972,10 +1060,15 @@ void congruence_closure::remove_parents(name const & R, expr const & e) {
auto ps = m_parents.find(child_key(R, e));
if (!ps) return;
ps->for_each([&](parent_occ const & p) {
auto lemma = mk_ext_congr_lemma(p.m_R, p.m_expr);
lean_assert(lemma);
congr_key k = mk_congr_key(*lemma, p.m_expr);
m_congruences.erase(k);
if (g_heq_based && R == get_eq_name() && p.m_R == get_eq_name()) {
eq_congr_key k = mk_eq_congr_key(p.m_expr);
m_eq_congruences.erase(k);
} else {
auto lemma = mk_ext_congr_lemma(p.m_R, p.m_expr);
lean_assert(lemma);
congr_key k = mk_congr_key(*lemma, p.m_expr);
m_congruences.erase(k);
}
});
}
@ -983,9 +1076,13 @@ void congruence_closure::reinsert_parents(name const & R, expr const & e) {
auto ps = m_parents.find(child_key(R, e));
if (!ps) return;
ps->for_each([&](parent_occ const & p) {
auto lemma = mk_ext_congr_lemma(p.m_R, p.m_expr);
lean_assert(lemma);
add_congruence_table(*lemma, p.m_expr);
if (g_heq_based && R == get_eq_name() && p.m_R == get_eq_name()) {
add_eq_congruence_table(p.m_expr);
} else {
auto lemma = mk_ext_congr_lemma(p.m_R, p.m_expr);
lean_assert(lemma);
add_congruence_table(*lemma, p.m_expr);
}
});
}
@ -1037,18 +1134,20 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
lean_assert(r1 && r2);
bool flipped = false;
// We want r2 to be the root of the combined class.
/* We want r2 to be the root of the combined class. */
// We swap (e1,n1,r1) with (e2,n2,r2) when
// 1- r1->m_interpreted && !r2->m_interpreted.
// Reason: to decide when to propagate we check whether the root of the equivalence class
// is true/false. So, this condition is to make sure if true/false is an equivalence class,
// then one of them is the root. If both are, it doesn't matter, since the state is inconsistent
// anyway.
// 2- r1->m_constructor && !r2->m_interpreted && !r2->m_constructor
// Reason: we want constructors to be the representative of their equivalence classes.
// 2- r1->m_size > r2->m_size && !r2->m_interpreted && !r2->m_constructor
// Reason: performance.
/*
We swap (e1,n1,r1) with (e2,n2,r2) when
1- r1->m_interpreted && !r2->m_interpreted.
Reason: to decide when to propagate we check whether the root of the equivalence class
is true/false. So, this condition is to make sure if true/false is an equivalence class,
then one of them is the root. If both are, it doesn't matter, since the state is inconsistent
anyway.
2- r1->m_constructor && !r2->m_interpreted && !r2->m_constructor
Reason: we want constructors to be the representative of their equivalence classes.
3- r1->m_size > r2->m_size && !r2->m_interpreted && !r2->m_constructor
Reason: performance.
*/
if ((r1->m_interpreted && !r2->m_interpreted) ||
(r1->m_constructor && !r2->m_interpreted && !r2->m_constructor) ||
(r1->m_size > r2->m_size && !r2->m_interpreted && !r2->m_constructor)) {
@ -1068,21 +1167,23 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
expr e2_root = n2->m_root;
entry new_n1 = *n1;
// Following target/proof we have
// e1 -> ... -> r1
// e2 -> ... -> r2
// We want
// r1 -> ... -> e1 -> e2 -> ... -> r2
/*
Following target/proof we have
e1 -> ... -> r1
e2 -> ... -> r2
We want
r1 -> ... -> e1 -> e2 -> ... -> r2
*/
invert_trans(R, e1);
new_n1.m_target = e2;
new_n1.m_proof = H;
new_n1.m_flipped = flipped;
m_entries.insert(eqc_key(R, e1), new_n1);
// The hash code for the parents is going to change
/* The hash code for the parents is going to change */
remove_parents(R, e1_root);
// force all m_root fields in e1 equivalence class to point to e2_root
/* force all m_root fields in e1 equivalence class to point to e2_root */
bool propagate = R == get_iff_name() && is_true_or_false(e2_root);
buffer<expr> to_propagate;
expr it = e1;
@ -1296,7 +1397,74 @@ bool congruence_closure::is_eqv(name const & R, expr const & e1, expr const & e2
return n1->m_root == n2->m_root;
}
expr congruence_closure::mk_eq_congr_proof(expr const & lhs, expr const & rhs, bool heq_proofs) const {
lean_assert(g_heq_based);
app_builder & b = get_app_builder();
buffer<expr> lhs_args, rhs_args;
expr const * lhs_it = &lhs;
expr const * rhs_it = &rhs;
while (is_app(*lhs_it) && is_app(*rhs_it) && *lhs_it != *rhs_it) {
lean_assert(is_eqv(get_eq_name(), *lhs_it, *rhs_it));
lhs_args.push_back(app_arg(*lhs_it));
rhs_args.push_back(app_arg(*rhs_it));
lhs_it = &app_fn(*lhs_it);
rhs_it = &app_fn(*rhs_it);
}
if (lhs_args.empty()) {
if (heq_proofs)
return b.mk_heq_refl(lhs);
else
return b.mk_eq_refl(lhs);
}
std::reverse(lhs_args.begin(), lhs_args.end());
std::reverse(rhs_args.begin(), rhs_args.end());
lean_assert(lhs_args.size() == rhs_args.size());
expr const & lhs_fn = *lhs_it;
expr const & rhs_fn = *rhs_it;
lean_assert(is_eqv(get_eq_name(), lhs_fn, rhs_fn) || is_def_eq(lhs_fn, rhs_fn));
lean_assert(is_def_eq(infer_type(lhs_fn), infer_type(rhs_fn)));
/* Create proof for
(lhs_fn lhs_args[0] ... lhs_args[n-1]) = (lhs_fn rhs_args[0] ... rhs_args[n-1])
where
n == lhs_args.size()
*/
auto spec_lemma = mk_ext_hcongr_lemma(lhs_fn, lhs_args.size());
lean_assert(spec_lemma);
list<congr_arg_kind> const * kinds_it = &spec_lemma->m_congr_lemma.get_arg_kinds();
buffer<expr> lemma_args;
for (unsigned i = 0; i < lhs_args.size(); i++) {
lean_assert(kinds_it);
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
if (head(*kinds_it) == congr_arg_kind::HEq) {
lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i]));
} else {
lean_assert(head(*kinds_it) == congr_arg_kind::Eq);
lemma_args.push_back(*get_eqv_proof(get_eq_name(), lhs_args[i], rhs_args[i]));
}
kinds_it = &(tail(*kinds_it));
}
expr r = mk_app(spec_lemma->m_congr_lemma.get_proof(), lemma_args);
if (spec_lemma->m_heq_result && !heq_proofs)
r = b.mk_eq_of_heq(r);
else if (!spec_lemma->m_heq_result && heq_proofs)
r = b.mk_heq_of_eq(r);
if (is_def_eq(lhs_fn, rhs_fn))
return r;
/* Convert r into a proof of lhs = rhs using eq.rec and
the proof that lhs_fn = rhs_fn */
expr lhs_fn_eq_rhs_fn = *get_eqv_proof(get_eq_name(), lhs_fn, rhs_fn);
expr x = mk_fresh_local(infer_type(lhs_fn));
expr motive_rhs = mk_app(x, rhs_args);
expr motive = heq_proofs ? b.mk_heq(lhs, motive_rhs) : b.mk_eq(lhs, motive_rhs);
return b.mk_eq_rec(Fun(x, motive), r, lhs_fn_eq_rhs_fn);
}
expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, expr const & rhs, bool heq_proofs) const {
if (g_heq_based && (R == get_eq_name() || R == get_heq_name())) {
/* Use general eq congruence lemmas when heterogeneous equality is enabled. */
return mk_eq_congr_proof(lhs, rhs, heq_proofs);
}
app_builder & b = get_app_builder();
buffer<expr> lhs_args, rhs_args;
expr const & lhs_fn = get_app_args(lhs, lhs_args);
@ -1305,92 +1473,49 @@ expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, e
auto lemma = mk_ext_congr_lemma(R, lhs);
lean_assert(lemma);
if (lemma->m_fixed_fun) {
if (g_heq_based && lemma->m_hcongr_lemma && (R == get_eq_name() || R == get_heq_name())) {
/* Try to simplify congruence proof by consuming common prefix of lhs and rhs */
/* This branch is an optimization, and it is not necessary */
unsigned i = 0;
for (; i < lhs_args.size(); i++) {
if (!is_def_eq(lhs_args[i], rhs_args[i]))
break;
}
unsigned prefix_sz = i;
unsigned rest_sz = lhs_args.size() - prefix_sz;
if (rest_sz == 0) {
if (heq_proofs)
return b.mk_heq_refl(lhs);
else
return b.mk_eq_refl(lhs);
}
expr g = lhs;
for (unsigned i = 0; i < rest_sz; i++) g = app_fn(g);
auto spec_lemma = mk_ext_hcongr_lemma(g, rest_sz);
lean_assert(spec_lemma);
list<congr_arg_kind> const * it = &spec_lemma->m_congr_lemma.get_arg_kinds();
buffer<expr> lemma_args;
for (unsigned i = prefix_sz; i < lhs_args.size(); i++) {
lean_assert(it);
/* Main case: convers user-defined congruence lemmas, and
all automatically generated congruence lemmas */
list<optional<name>> const * it1 = &lemma->m_rel_names;
list<congr_arg_kind> const * it2 = &lemma->m_congr_lemma.get_arg_kinds();
buffer<expr> lemma_args;
for (unsigned i = 0; i < lhs_args.size(); i++) {
lean_assert(*it1 && *it2);
switch (head(*it2)) {
case congr_arg_kind::HEq:
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
if (head(*it) == congr_arg_kind::HEq) {
lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i]));
} else {
lean_assert(head(*it) == congr_arg_kind::Eq);
lemma_args.push_back(*get_eqv_proof(get_eq_name(), lhs_args[i], rhs_args[i]));
}
it = &(tail(*it));
lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i]));
break;
case congr_arg_kind::Eq:
lean_assert(head(*it1));
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
lemma_args.push_back(*get_eqv_proof(*head(*it1), lhs_args[i], rhs_args[i]));
break;
case congr_arg_kind::Fixed:
lemma_args.push_back(lhs_args[i]);
break;
case congr_arg_kind::FixedNoParam:
break;
case congr_arg_kind::Cast:
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
break;
}
expr r = mk_app(spec_lemma->m_congr_lemma.get_proof(), lemma_args);
lean_assert(g_heq_based);
if (spec_lemma->m_heq_result && !heq_proofs)
r = b.mk_eq_of_heq(r);
else if (!spec_lemma->m_heq_result && heq_proofs)
r = b.mk_heq_of_eq(r);
return r;
} else {
/* Main case: convers user-defined congruence lemmas, and
all automatically generated congruence lemmas */
list<optional<name>> const * it1 = &lemma->m_rel_names;
list<congr_arg_kind> const * it2 = &lemma->m_congr_lemma.get_arg_kinds();
buffer<expr> lemma_args;
for (unsigned i = 0; i < lhs_args.size(); i++) {
lean_assert(*it1 && *it2);
switch (head(*it2)) {
case congr_arg_kind::HEq:
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i]));
break;
case congr_arg_kind::Eq:
lean_assert(head(*it1));
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
lemma_args.push_back(*get_eqv_proof(*head(*it1), lhs_args[i], rhs_args[i]));
break;
case congr_arg_kind::Fixed:
lemma_args.push_back(lhs_args[i]);
break;
case congr_arg_kind::FixedNoParam:
break;
case congr_arg_kind::Cast:
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
break;
}
it1 = &(tail(*it1));
it2 = &(tail(*it2));
}
expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args);
if (lemma->m_lift_needed) {
r = b.lift_from_eq(R, r);
}
if (g_heq_based) {
if (lemma->m_heq_result && !heq_proofs)
r = b.mk_eq_of_heq(r);
else if (!lemma->m_heq_result && heq_proofs)
r = b.mk_heq_of_eq(r);
}
return r;
it1 = &(tail(*it1));
it2 = &(tail(*it2));
}
expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args);
if (lemma->m_lift_needed) {
r = b.lift_from_eq(R, r);
}
if (g_heq_based) {
if (lemma->m_heq_result && !heq_proofs)
r = b.mk_eq_of_heq(r);
else if (!lemma->m_heq_result && heq_proofs)
r = b.mk_heq_of_eq(r);
}
return r;
} else {
/* This branch builds congruence proofs that handle equality between functions.
The proof is created using congr_arg/congr_fun/congr lemmas.

View file

@ -76,7 +76,22 @@ class congruence_closure {
}
};
/* Key for the congruence set */
/* Key for the equality congruence table.
\remark We only use the equality congruence table when the support for heterogeneous equality
is turned on (see blast.cc.heq option). Otherwise, we store equality congruences in the
generic congruence table and rely on automatically generated congruence lemmas for
weakly dependent functions. */
struct eq_congr_key {
expr m_expr;
unsigned m_hash;
};
struct eq_congr_key_cmp {
int operator()(eq_congr_key const & k1, eq_congr_key const & k2) const;
};
/* Key for the congruence table. */
struct congr_key {
name m_R;
expr m_expr;
@ -93,6 +108,9 @@ class congruence_closure {
congr_key() { m_eq = 0; m_iff = 0; m_symm_rel = 0; }
};
int cmp_eq_iff_keys(congr_key const & k1, congr_key const & k2) const;
int cmp_symm_rel_keys(congr_key const & k1, congr_key const & k2) const;
struct congr_key_cmp {
int operator()(congr_key const & k1, congr_key const & k2) const;
};
@ -105,10 +123,12 @@ class congruence_closure {
typedef eqc_key_cmp parent_occ_cmp;
typedef rb_tree<parent_occ, parent_occ_cmp> parent_occ_set;
typedef rb_map<child_key, parent_occ_set, child_key_cmp> parents;
typedef rb_tree<eq_congr_key, eq_congr_key_cmp> eq_congruences;
typedef rb_tree<congr_key, congr_key_cmp> congruences;
typedef rb_map<expr, expr, expr_quick_cmp> subsingleton_reprs;
entries m_entries;
parents m_parents;
eq_congruences m_eq_congruences;
congruences m_congruences;
/** The following mapping store a representative for each subsingleton type */
subsingleton_reprs m_subsingleton_reprs;
@ -129,6 +149,7 @@ class congruence_closure {
int compare_symm(name const & R, expr lhs1, expr rhs1, expr lhs2, expr rhs2) const;
int compare_root(name const & R, expr e1, expr e2) const;
eq_congr_key mk_eq_congr_key(expr const & e) const;
unsigned symm_hash(name const & R, expr const & lhs, expr const & rhs) const;
congr_key mk_congr_key(ext_congr_lemma const & lemma, expr const & e) const;
void check_iff_true(congr_key const & k);
@ -140,6 +161,7 @@ class congruence_closure {
void mk_entry_core(name const & R, expr const & e, bool to_propagate);
void mk_entry(name const & R, expr const & e, bool to_propagate);
void add_occurrence(name const & Rp, expr const & parent, name const & Rc, expr const & child);
void add_eq_congruence_table(expr const & e);
void add_congruence_table(ext_congr_lemma const & lemma, expr const & e);
void invert_trans(name const & R, expr const & e, bool new_flipped, optional<expr> new_target, optional<expr> new_proof);
void invert_trans(name const & R, expr const & e);
@ -152,6 +174,7 @@ class congruence_closure {
void add_eqv_core(name const & R, expr const & lhs, expr const & rhs, expr const & H, optional<expr> const & added_prop, bool heq_proof);
void propagate_no_confusion_eq(expr const & e1, expr const & e2);
expr mk_eq_congr_proof(expr const & lhs, expr const & rhs, bool heq_proofs) const;
expr mk_congr_proof_core(name const & R, expr const & lhs, expr const & rhs, bool heq_proofs) const;
expr mk_congr_proof(name const & R, expr const & lhs, expr const & rhs, bool heq_proofs) const;
expr mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H, bool heq_proofs) const;

View file

@ -0,0 +1,8 @@
example (f g : Π {A : Type₁}, A → A → A) (h : nat → nat) (a b : nat) :
h = f a → h b = f a b :=
by blast
example (f g : Π {A : Type₁} (a b : A), A) (h : nat → nat) (a b : nat) :
h = f a → h b = f a b :=
by blast