feat(library/tactic/congruence/theory_ac): add internalization, interface with congruence closure module, and trivial/simp/orient transitions

Still missing: superpose, collapse and compose transitions.
This commit is contained in:
Leonardo de Moura 2016-12-28 21:24:46 -08:00
parent bb37b33237
commit 9065cf0350
13 changed files with 872 additions and 55 deletions

View file

@ -383,6 +383,10 @@ public:
unsigned hash() const { return m_ptr->hash(); }
void write(serializer & s) const { return m_ptr->write(s); }
macro_definition_cell const * raw() const { return m_ptr; }
friend bool is_eqp(macro_definition const & d1, macro_definition const & d2) {
return d1.m_ptr == d2.m_ptr;
}
};
/** \brief Macro attachments */

View file

@ -499,6 +499,8 @@ public:
}
}
expr mk_eq_symm(expr const & H) {
if (is_app_of(H, get_eq_refl_name()))
return H;
expr p = m_ctx.relaxed_whnf(m_ctx.infer(H));
expr A, lhs, rhs;
if (!is_eq(p, A, lhs, rhs)) {
@ -508,6 +510,11 @@ public:
level lvl = get_level(A);
return ::lean::mk_app(mk_constant(get_eq_symm_name(), {lvl}), A, lhs, rhs, H);
}
expr mk_eq_symm(expr const & a, expr const & b, expr const & H) {
expr A = m_ctx.infer(a);
level lvl = get_level(A);
return ::lean::mk_app(mk_constant(get_eq_symm_name(), {lvl}), A, a, b, H);
}
expr mk_iff_symm(expr const & H) {
expr p = m_ctx.infer(H);
expr lhs, rhs;
@ -547,6 +554,10 @@ public:
}
}
expr mk_eq_trans(expr const & H1, expr const & H2) {
if (is_app_of(H1, get_eq_refl_name()))
return H2;
if (is_app_of(H2, get_eq_refl_name()))
return H1;
expr p1 = m_ctx.relaxed_whnf(m_ctx.infer(H1));
expr p2 = m_ctx.relaxed_whnf(m_ctx.infer(H2));
expr A, lhs1, rhs1, lhs2, rhs2;
@ -559,6 +570,15 @@ public:
level lvl = get_level(A);
return ::lean::mk_app({mk_constant(get_eq_trans_name(), {lvl}), A, lhs1, rhs1, rhs2, H1, H2});
}
expr mk_eq_trans(expr const & a, expr const & b, expr const & c, expr const & H1, expr const & H2) {
if (is_app_of(H1, get_eq_refl_name()))
return H2;
if (is_app_of(H2, get_eq_refl_name()))
return H1;
expr A = m_ctx.infer(a);
level lvl = get_level(A);
return ::lean::mk_app({mk_constant(get_eq_trans_name(), {lvl}), A, a, b, c, H1, H2});
}
expr mk_iff_trans(expr const & H1, expr const & H2) {
expr p1 = m_ctx.infer(H1);
expr p2 = m_ctx.infer(H2);
@ -820,6 +840,10 @@ expr mk_eq_symm(type_context & ctx, expr const & H) {
return app_builder(ctx).mk_eq_symm(H);
}
expr mk_eq_symm(type_context & ctx, expr const & a, expr const & b, expr const & H) {
return app_builder(ctx).mk_eq_symm(a, b, H);
}
expr mk_iff_symm(type_context & ctx, expr const & H) {
return app_builder(ctx).mk_iff_symm(H);
}
@ -836,6 +860,10 @@ expr mk_eq_trans(type_context & ctx, expr const & H1, expr const & H2) {
return app_builder(ctx).mk_eq_trans(H1, H2);
}
expr mk_eq_trans(type_context & ctx, expr const & a, expr const & b, expr const & c, expr const & H1, expr const & H2) {
return app_builder(ctx).mk_eq_trans(a, b, c, H1, H2);
}
expr mk_iff_trans(type_context & ctx, expr const & H1, expr const & H2) {
return app_builder(ctx).mk_iff_trans(H1, H2);
}

View file

@ -91,12 +91,14 @@ expr mk_heq_refl(type_context & ctx, expr const & a);
/** \brief Similar a symmetry proof for the given relation */
expr mk_symm(type_context & ctx, name const & relname, expr const & H);
expr mk_eq_symm(type_context & ctx, expr const & H);
expr mk_eq_symm(type_context & ctx, expr const & a, expr const & b, expr const & H);
expr mk_iff_symm(type_context & ctx, expr const & H);
expr mk_heq_symm(type_context & ctx, expr const & H);
/** \brief Similar a transitivity proof for the given relation */
expr mk_trans(type_context & ctx, name const & relname, expr const & H1, expr const & H2);
expr mk_eq_trans(type_context & ctx, expr const & H1, expr const & H2);
expr mk_eq_trans(type_context & ctx, expr const & a, expr const & b, expr const & c, expr const & H1, expr const & H2);
expr mk_iff_trans(type_context & ctx, expr const & H1, expr const & H2);
expr mk_heq_trans(type_context & ctx, expr const & H1, expr const & H2);

View file

@ -18,6 +18,7 @@ namespace lean {
bool is_lt(expr const & a, expr const & b, bool use_hash);
/** \brief Similar to is_lt, but universe level parameter names are ignored. */
bool is_lt_no_level_params(expr const & a, expr const & b);
inline bool is_hash_lt(expr const & a, expr const & b) { return is_lt(a, b, true); }
inline bool operator<(expr const & a, expr const & b) { return is_lt(a, b, true); }
inline bool operator>(expr const & a, expr const & b) { return is_lt(b, a, true); }
inline bool operator<=(expr const & a, expr const & b) { return !is_lt(b, a, true); }

View file

@ -1,2 +1,2 @@
add_library(congruence OBJECT congruence_closure.cpp congruence_tactics.cpp
hinst_lemmas.cpp ematch.cpp theory_ac.cpp init_module.cpp)
hinst_lemmas.cpp ematch.cpp theory_ac.cpp util.cpp init_module.cpp)

View file

@ -17,6 +17,7 @@ Author: Leonardo de Moura
#include "library/app_builder.h"
#include "library/projection.h"
#include "library/constructions/constructor.h"
#include "library/tactic/congruence/util.h"
#include "library/tactic/congruence/congruence_closure.h"
namespace lean {
@ -653,15 +654,15 @@ void congruence_closure::propagate_inst_implicit(expr const & e) {
}
}
void congruence_closure::set_ac_rep(expr const & e, expr const & ac_rep) {
void congruence_closure::set_ac_var(expr const & e) {
expr e_root = get_root(e);
auto n = get_entry(e);
if (n->m_ac_rep) {
m_ac.add_eq(*n->m_ac_rep, ac_rep);
auto root_entry = get_entry(e_root);
if (root_entry->m_ac_var) {
m_ac.add_eq(*root_entry->m_ac_var, e);
} else {
entry new_entry = *n;
new_entry.m_ac_rep = some_expr(ac_rep);
m_state.m_entries.insert(e_root, new_entry);
entry new_root_entry = *root_entry;
new_root_entry.m_ac_var = some_expr(e);
m_state.m_entries.insert(e_root, new_root_entry);
}
}
@ -807,9 +808,8 @@ void congruence_closure::internalize_core(expr const & e, bool toplevel, bool to
break;
}}
if (optional<expr> ac_rep = m_ac.internalize(e, parent)) {
set_ac_rep(e, *ac_rep);
}
if (m_state.m_config.m_ac)
m_ac.internalize(e, parent);
}
/*
@ -897,19 +897,26 @@ bool congruence_closure::has_heq_proofs(expr const & root) const {
return get_entry(root)->m_heq_proofs;
}
expr congruence_closure::flip_proof_core(expr const & H, bool flipped, bool heq_proofs) const {
expr new_H = H;
if (heq_proofs && is_eq(m_ctx.relaxed_whnf(m_ctx.infer(new_H)))) {
new_H = mk_heq_of_eq(m_ctx, new_H);
}
if (!flipped) {
return new_H;
} else {
return heq_proofs ? mk_heq_symm(m_ctx, new_H) : mk_eq_symm(m_ctx, new_H);
}
}
expr congruence_closure::flip_proof(expr const & H, bool flipped, bool heq_proofs) const {
if (H == *g_congr_mark || H == *g_eq_true_mark || H == *g_refl_mark) {
return H;
} else if (is_cc_theory_proof(H)) {
expr H1 = flip_proof_core(get_cc_theory_proof_arg(H), flipped, heq_proofs);
return mark_cc_theory_proof(H1);
} else {
expr new_H = H;
if (heq_proofs && is_eq(m_ctx.relaxed_whnf(m_ctx.infer(new_H)))) {
new_H = mk_heq_of_eq(m_ctx, new_H);
}
if (!flipped) {
return new_H;
} else {
return heq_proofs ? mk_heq_symm(m_ctx, new_H) : mk_eq_symm(m_ctx, new_H);
}
return flip_proof_core(H, flipped, heq_proofs);
}
}
@ -1073,6 +1080,8 @@ expr congruence_closure::mk_proof(expr const & lhs, expr const & rhs, expr const
expr type = heq_proofs ? mk_heq(m_ctx, lhs, rhs) : mk_eq(m_ctx, lhs, rhs);
expr proof = heq_proofs ? mk_heq_refl(m_ctx, lhs) : mk_eq_refl(m_ctx, lhs);
return mk_app(mk_constant(get_id_locked_name(), {mk_level_zero()}), type, proof);
} else if (is_cc_theory_proof(H)) {
return expand_delayed_cc_proofs(*this, get_cc_theory_proof_arg(H));
} else {
return H;
}
@ -1376,10 +1385,10 @@ void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H,
new_r1.m_next = r2->m_next;
new_r2.m_next = r1->m_next;
new_r2.m_size += r1->m_size;
optional<expr> ac_rep1 = r1->m_ac_rep;
optional<expr> ac_rep2 = r2->m_ac_rep;
if (!ac_rep2)
new_r2.m_ac_rep = ac_rep1;
optional<expr> ac_var1 = r1->m_ac_var;
optional<expr> ac_var2 = r2->m_ac_var;
if (!ac_var2)
new_r2.m_ac_var = ac_var1;
if (heq_proof)
new_r2.m_heq_proofs = true;
m_state.m_entries.insert(e1_root, new_r1);
@ -1404,8 +1413,8 @@ void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H,
m_state.m_parents.insert(e2_root, ps2);
}
if (ac_rep1 && ac_rep2)
m_ac.add_eq(*ac_rep1, *ac_rep2);
if (ac_var1 && ac_var2)
m_ac.add_eq(*ac_var1, *ac_var2);
// propagate new hypotheses back to current state
if (!to_propagate.empty()) {

View file

@ -53,8 +53,8 @@ class congruence_closure {
store 'target' at 'm_target', and 'H' at 'm_proof'. Both fields are none if 'e' == m_root */
optional<expr> m_target;
optional<expr> m_proof;
/* Representative used in the AC theory */
optional<expr> m_ac_rep;
/* Variable in the AC theory. */
optional<expr> m_ac_var;
unsigned m_flipped:1; // proof has been flipped
unsigned m_to_propagate:1; // must be propagated back to state when in equivalence class containing true/false
unsigned m_interpreted:1; // true if the node should be viewed as an abstract value
@ -102,7 +102,8 @@ public:
unsigned m_ignore_instances:1;
unsigned m_values:1;
unsigned m_all_ho:1;
config() { m_ignore_instances = true; m_values = true; m_all_ho = false; }
unsigned m_ac:1;
config() { m_ignore_instances = true; m_values = true; m_all_ho = false; m_ac = true; }
};
class state {
@ -165,6 +166,8 @@ private:
refl_info_getter m_refl_info_getter;
theory_ac m_ac;
friend class theory_ac;
int compare_symm(expr lhs1, expr rhs1, expr lhs2, expr rhs2) const;
unsigned symm_hash(expr const & lhs, expr const & rhs) const;
optional<name> is_binary_relation(expr const & e, expr & lhs, expr & rhs) const;
@ -185,10 +188,11 @@ private:
void add_symm_congruence_table(expr const & e);
void mk_entry_core(expr const & e, bool to_propagate, bool interpreted = false);
void mk_entry(expr const & e, bool to_propagate);
void set_ac_rep(expr const & e, expr const & ac_rep);
void set_ac_var(expr const & e);
void internalize_app(expr const & e, bool toplevel, bool to_propagate);
void internalize_core(expr const & e, bool toplevel, bool to_propagate, optional<expr> const & parent);
void push_todo(expr const & lhs, expr const & rhs, expr const & H, bool heq_proof);
void push_new_eq(expr const & lhs, expr const & rhs, expr const & H) { push_todo(lhs, rhs, H, false); }
void push_refl_eq(expr const & lhs, expr const & rhs);
void invert_trans(expr const & e, bool new_flipped, optional<expr> new_target, optional<expr> new_proof);
void invert_trans(expr const & e);
@ -196,6 +200,7 @@ private:
void reinsert_parents(expr const & e);
void update_mt(expr const & e);
bool has_heq_proofs(expr const & root) const;
expr flip_proof_core(expr const & H, bool flipped, bool heq_proofs) const;
expr flip_proof(expr const & H, bool flipped, bool heq_proofs) const;
optional<ext_congr_lemma> mk_ext_hcongr_lemma(expr const & fn, unsigned nargs) const;
expr mk_trans(expr const & H1, expr const & H2, bool heq_proofs) const;
@ -219,7 +224,6 @@ private:
bool check_eqc(expr const & e) const;
friend ext_congr_lemma_cache_ptr const & get_cache_ptr(congruence_closure const & cc);
public:
congruence_closure(type_context & ctx, state & s);
~congruence_closure();
@ -227,6 +231,7 @@ public:
environment const & env() const { return m_ctx.env(); }
type_context & ctx() { return m_ctx; }
transparency_mode mode() const { return m_mode; }
defeq_canonizer & get_defeq_canonizer() { return m_defeq_canonizer; }
/** \brief Register expression \c e in this data-structure.
It creates entries for each sub-expression in \c e.

View file

@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include "library/tactic/congruence/util.h"
#include "library/tactic/congruence/congruence_closure.h"
#include "library/tactic/congruence/congruence_tactics.h"
#include "library/tactic/congruence/hinst_lemmas.h"
@ -12,6 +13,7 @@ Author: Leonardo de Moura
namespace lean {
void initialize_congruence_module() {
initialize_congruence_util();
initialize_congruence_closure();
initialize_congruence_tactics();
initialize_hinst_lemmas();
@ -25,5 +27,6 @@ void finalize_congruence_module() {
finalize_hinst_lemmas();
finalize_congruence_tactics();
finalize_congruence_closure();
finalize_congruence_util();
}
}

View file

@ -4,11 +4,385 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <algorithm>
#include <string>
#include "library/trace.h"
#include "library/app_builder.h"
#include "library/kernel_serializer.h"
#include "library/tactic/ac_tactics.h"
#include "library/tactic/congruence/util.h"
#include "library/tactic/congruence/congruence_closure.h"
#include "library/tactic/congruence/theory_ac.h"
/* TODO(Leo): reduce after testing */
#define AC_TRUST_LEVEL 10000000
namespace lean {
enum class ac_term_kind { App };
static name * g_ac_app_name = nullptr;
static macro_definition * g_ac_app_macro = nullptr;
static std::string * g_ac_app_opcode = nullptr;
static expr expand_ac_core(expr const & e) {
unsigned nargs = macro_num_args(e);
unsigned i = nargs - 1;
expr const & op = macro_arg(e, i);
--i;
expr r = macro_arg(e, i);
while (i > 0) {
--i;
r = mk_app(op, macro_arg(e, i), r);
}
return r;
}
class ac_app_macro_cell : public macro_definition_cell {
public:
virtual name get_name() const { return *g_ac_app_name; }
virtual unsigned trust_level() const { return AC_TRUST_LEVEL; }
virtual expr check_type(expr const & e, abstract_type_context & ctx, bool) const {
return ctx.infer(macro_arg(e, 0));
}
virtual optional<expr> expand(expr const & e, abstract_type_context &) const {
return some_expr(expand_ac_core(e));
}
virtual void write(serializer & s) const {
s.write_string(*g_ac_app_opcode);
}
virtual bool operator==(macro_definition_cell const & other) const {
ac_app_macro_cell const * other_ptr = dynamic_cast<ac_app_macro_cell const *>(&other);
return other_ptr;
}
virtual unsigned hash() const { return 37; }
};
static expr mk_ac_app_core(unsigned nargs, expr const * args_op) {
lean_assert(nargs >= 3);
return mk_macro(*g_ac_app_macro, nargs, args_op);
}
static expr mk_ac_app_core(expr const & op, buffer<expr> & args) {
lean_assert(args.size() >= 2);
args.push_back(op);
expr r = mk_ac_app_core(args.size(), args.data());
args.pop_back();
return r;
}
static expr mk_ac_app(expr const & op, buffer<expr> & args) {
lean_assert(args.size() > 0);
if (args.size() == 1) {
return args[0];
} else {
std::sort(args.begin(), args.end(), is_hash_lt);
return mk_ac_app_core(op, args);
}
}
static bool is_ac_app(expr const & e) {
return is_macro(e) && is_eqp(macro_def(e), *g_ac_app_macro);
}
static expr const & get_ac_app_op(expr const & e) {
lean_assert(is_ac_app(e));
return macro_arg(e, macro_num_args(e) - 1);
}
static unsigned get_ac_app_num_args(expr const & e) {
lean_assert(is_ac_app(e));
return macro_num_args(e) - 1;
}
static expr const * get_ac_app_args(expr const & e) {
lean_assert(is_ac_app(e));
return macro_args(e);
}
/* Return true iff e1 is a "subset" of e2.
Example: The result is true for e1 := (a*a*a*b*d) and e2 := (a*a*a*a*b*b*c*d*d) */
static bool is_ac_subset(expr const & e1, expr const & e2) {
if (is_ac_app(e1)) {
if (is_ac_app(e2) && get_ac_app_op(e1) == get_ac_app_op(e2)) {
unsigned nargs1 = get_ac_app_num_args(e1);
unsigned nargs2 = get_ac_app_num_args(e2);
if (nargs1 > nargs2) return false;
expr const * args1 = get_ac_app_args(e1);
expr const * args2 = get_ac_app_args(e2);
unsigned i1 = 0;
unsigned i2 = 0;
while (i1 < nargs1 && i2 < nargs2) {
if (args1[i1] == args2[i2]) {
i1++;
i2++;
} else if (is_hash_lt(args2[i2], args1[i1])) {
i2++;
} else {
lean_assert(is_hash_lt(args1[i1], args2[i2]));
return false;
}
}
return i1 == nargs1;
} else {
return false;
}
} else {
if (is_ac_app(e2)) {
unsigned nargs2 = get_ac_app_num_args(e2);
expr const * args2 = get_ac_app_args(e2);
return std::find(args2, args2+nargs2, e1) != args2+nargs2;
} else {
return e1 == e2;
}
}
}
/* Store in r e1\e2.
Example: given e1 := (a*a*a*a*b*b*c*d*d*d) and e2 := (a*a*a*b*b*d),
the result is (a, c, d, d)
\pre is_ac_subset(e2, e1) */
static void ac_diff(expr const & e1, expr const & e2, buffer<expr> & r) {
lean_assert(is_ac_subset(e2, e1));
if (is_ac_app(e1)) {
if (is_ac_app(e2) && get_ac_app_op(e1) == get_ac_app_op(e2)) {
unsigned nargs1 = get_ac_app_num_args(e1);
unsigned nargs2 = get_ac_app_num_args(e2);
lean_assert(nargs1 >= nargs2);
expr const * args1 = get_ac_app_args(e1);
expr const * args2 = get_ac_app_args(e2);
unsigned i2 = 0;
for (unsigned i1 = 0; i1 < nargs1; i1++) {
if (i2 == nargs2) {
r.push_back(args1[i1]);
} else if (args1[i1] == args2[i2]) {
i2++;
} else {
lean_assert(is_hash_lt(args1[i1], args2[i2]));
r.push_back(args1[i1]);
}
}
} else {
bool found = false;
unsigned nargs1 = get_ac_app_num_args(e1);
expr const * args1 = get_ac_app_args(e1);
for (unsigned i = 0; i < nargs1; i++) {
if (!found && args1[i] == e2) {
found = true;
} else {
r.push_back(args1[i]);
}
}
lean_assert(found);
}
} else {
lean_assert(!is_ac_app(e1));
lean_assert(!is_ac_app(e2));
lean_assert(e1 == e2);
}
}
static void ac_append(expr const & e, buffer<expr> & r) {
if (is_ac_app(e)) {
r.append(get_ac_app_num_args(e), get_ac_app_args(e));
} else {
r.push_back(e);
}
}
/* lexdeg order */
static bool ac_lt(expr const & e1, expr const & e2) {
if (is_ac_app(e1)) {
if (is_ac_app(e2) && get_ac_app_op(e1) == get_ac_app_op(e2)) {
unsigned nargs1 = get_ac_app_num_args(e1);
unsigned nargs2 = get_ac_app_num_args(e2);
if (nargs1 < nargs2) return true;
if (nargs1 > nargs2) return false;
expr const * args1 = get_ac_app_args(e1);
expr const * args2 = get_ac_app_args(e2);
for (unsigned i = 0; i < nargs1; i++) {
if (args1[i] != args2[i])
return is_hash_lt(args1[i], args2[i]);
}
return false;
} else {
return false;
}
} else {
if (is_ac_app(e2)) {
return true;
} else {
return is_hash_lt(e1, e2);
}
}
}
static expr expand_if_ac_app(expr const & e) {
if (is_ac_app(e))
return expand_ac_core(e);
else
return e;
}
static name * g_ac_simp_name = nullptr;
static macro_definition * g_ac_simp_macro = nullptr;
static std::string * g_ac_simp_opcode = nullptr;
class ac_simp_macro_cell : public macro_definition_cell {
public:
virtual name get_name() const { return *g_ac_simp_name; }
virtual expr check_type(expr const & e, abstract_type_context & ctx, bool) const {
return mk_eq(ctx, macro_arg(e, 0), macro_arg(e, 3));
}
virtual unsigned trust_level() const { return AC_TRUST_LEVEL; }
virtual optional<expr> expand(expr const & H, abstract_type_context & ctx) const {
expr e = expand_if_ac_app(macro_arg(H, 0)); /* it is of the form t*r */
expr t = expand_if_ac_app(macro_arg(H, 1));
expr s = expand_if_ac_app(macro_arg(H, 2));
expr r = expand_if_ac_app(macro_arg(H, 3));
expr sr = expand_if_ac_app(macro_arg(H, 4));
expr t_eq_s = expand_if_ac_app(macro_arg(H, 5));
expr const & assoc = macro_arg(H, 6);
expr const & comm = macro_arg(H, 7);
if (e == sr) {
return some_expr(mk_eq_refl(ctx, e));
} else if (e == t) {
lean_assert(s == sr);
return some_expr(t_eq_s);
} else {
expr op = app_fn(app_fn(e));
expr op_r = mk_app(op, r);
expr rt = mk_app(op_r, t);
expr rs = mk_app(op, r, s);
expr rt_eq_rs = mk_congr_arg(ctx, op_r, t_eq_s);
expr e_eq_rt = perm_ac(ctx, op, assoc, comm, e, rt);
expr rs_eq_sr = perm_ac(ctx, op, assoc, comm, rs, sr);
return some_expr(mk_eq_trans(ctx, mk_eq_trans(ctx, e_eq_rt, rt_eq_rs), rs_eq_sr));
}
}
virtual void write(serializer & s) const {
s.write_string(*g_ac_simp_opcode);
}
virtual bool operator==(macro_definition_cell const & other) const {
ac_simp_macro_cell const * other_ptr = dynamic_cast<ac_simp_macro_cell const *>(&other);
return other_ptr;
}
virtual unsigned hash() const { return 31; }
};
/* Given e of the form t*r, (pr : t = s) and s_r is of the form (s*r),
return a proof for e = s_r */
static expr mk_ac_simp_proof(expr const & e, expr const & t, expr const & s, expr const & r, expr const & s_r, expr const & pr, expr const & assoc, expr const & comm) {
expr args[8] = {e, t, s, r, s_r, pr, assoc, comm};
return mk_macro(*g_ac_simp_macro, 8, args);
}
static name * g_ac_refl_name = nullptr;
static macro_definition * g_ac_refl_macro = nullptr;
static std::string * g_ac_refl_opcode = nullptr;
class ac_refl_macro_cell : public macro_definition_cell {
public:
virtual name get_name() const { return *g_ac_refl_name; }
virtual expr check_type(expr const & e, abstract_type_context & ctx, bool) const {
return mk_eq(ctx, macro_arg(e, 0), macro_arg(e, 2));
}
virtual unsigned trust_level() const { return AC_TRUST_LEVEL; }
virtual optional<expr> expand(expr const & e, abstract_type_context & ctx) const {
expr const & t = macro_arg(e, 0);
expr ac_t = macro_arg(e, 1);
if (is_ac_app(ac_t))
ac_t = expand_ac_core(ac_t);
expr const & op = app_fn(app_fn(ac_t));
expr const & assoc = macro_arg(e, 2);
expr const & comm = macro_arg(e, 3);
return some_expr(perm_ac(ctx, op, assoc, comm, t, ac_t));
}
virtual void write(serializer & s) const {
s.write_string(*g_ac_refl_opcode);
}
virtual bool operator==(macro_definition_cell const & other) const {
ac_refl_macro_cell const * other_ptr = dynamic_cast<ac_refl_macro_cell const *>(&other);
return other_ptr;
}
virtual unsigned hash() const { return 31; }
};
/* Given e and ac_term that is provably equal to e using AC, return a proof for e = ac_term */
static expr mk_ac_refl_proof(expr const & e, expr const & ac_term, expr const & assoc, expr const & comm) {
expr args[4] = {e, ac_term, assoc, comm};
return mk_macro(*g_ac_refl_macro, 4, args);
}
static char const * ac_var_prefix = "x_";
format theory_ac::state::pp_term(formatter const & fmt, expr const & e) const {
if (auto it = m_entries.find(e)) {
return format(ac_var_prefix) + format(it->m_idx);
} else if (is_ac_app(e)) {
format r = fmt(get_ac_app_op(e));
unsigned nargs = get_ac_app_num_args(e);
expr const * args = get_ac_app_args(e);
for (unsigned i = 0; i < nargs; i++) {
r += line() + pp_term(fmt, args[i]);
}
return group(bracket("[", r, "]"));
} else {
tout() << "pp_term: " << e << "\n";
lean_unreachable();
}
}
format theory_ac::state::pp_decl(formatter const & fmt, expr const & e) const {
lean_assert(m_entries.contains(e));
auto it = m_entries.find(e);
return group(format(ac_var_prefix) + format(it->m_idx) + line() + format(":=") + line() + fmt(e));
}
format theory_ac::state::pp_decls(formatter const & fmt) const {
format r;
bool first = true;
m_entries.for_each([&](expr const & k, entry const &) {
if (first) first = false; else r += comma() + line();
r += pp_decl(fmt, k);
});
return group(bracket("{", r, "}"));
}
format theory_ac::state::pp_R(formatter const & fmt) const {
format r;
unsigned indent = get_pp_indent(fmt.get_options());
bool first = true;
m_R.for_each([&](expr const & k, expr_pair const & p) {
format curr = pp_term(fmt, k) + line() + format("-->") + nest(indent, line() + pp_term(fmt, p.first));
if (first) first = false; else r += comma() + line();
r += group(curr);
});
return group(bracket("{", r, "}"));
}
format theory_ac::state::pp(formatter const & fmt) const {
return group(bracket("[", pp_decls(fmt) + comma() + line() + pp_R(fmt), "]"));
}
theory_ac::theory_ac(congruence_closure & cc, state & s):
m_ctx(cc.ctx()),
m_cc(cc),
@ -24,7 +398,8 @@ optional<expr> theory_ac::is_ac(expr const & e) {
if (!assoc_pr) return none_expr();
optional<expr> comm_pr = m_ac_manager.is_comm(e);
if (!comm_pr) return none_expr();
expr const & op = app_fn(app_fn(e));
expr op = app_fn(app_fn(e));
op = m_cc.get_defeq_canonizer().canonize(op);
if (auto it = m_state.m_can_ops.find(op))
return some_expr(*it);
optional<expr> found_op;
@ -42,31 +417,241 @@ optional<expr> theory_ac::is_ac(expr const & e) {
}
}
optional<expr> theory_ac::internalize(expr const & e, optional<expr> const & parent) {
expr theory_ac::convert(expr const & op, expr const & e, buffer<expr> & args) {
if (optional<expr> curr_op = is_ac(e)) {
if (op == *curr_op) {
expr arg1 = convert(op, app_arg(app_fn(e)), args);
expr arg2 = convert(op, app_arg(e), args);
return mk_app(op, arg1, arg2);
}
}
internalize_var(e);
args.push_back(e);
return e;
}
bool theory_ac::internalize_var(expr const & e) {
if (m_state.m_entries.contains(e)) return false;
m_state.m_entries.insert(e, entry(m_state.m_next_idx));
m_state.m_next_idx++;
m_cc.set_ac_var(e);
return true;
}
void theory_ac::dbg_trace_state() const {
lean_trace(name({"debug", "cc", "ac"}), scope_trace_env s(m_ctx.env(), m_ctx);
auto out = tout();
out << m_state.pp(out.get_formatter()) << "\n";);
}
void theory_ac::dbg_trace_eq(char const * header, expr const & lhs, expr const & rhs) const {
lean_trace(name({"debug", "cc", "ac"}), scope_trace_env s(m_ctx.env(), m_ctx);
auto out = tout();
auto fmt = out.get_formatter();
out << header << " " << pp_term(fmt, lhs) << " = " << pp_term(fmt, rhs) << "\n";);
}
void theory_ac::internalize(expr const & e, optional<expr> const & parent) {
auto op = is_ac(e);
if (!op) return none_expr();
if (!op) return;
optional<expr> parent_op;
if (parent) parent_op = is_ac(*parent);
if (parent_op && *op == *parent_op) return none_expr();
if (parent_op && *op == *parent_op) return;
// TODO(Leo): compute representative and initialize
expr rep = e;
if (!internalize_var(e)) return;
lean_trace(name({"cc", "ac"}), scope_trace_env s(m_ctx.env(), m_ctx);
tout() << "new term: " << rep << "\n";);
return some_expr(rep);
buffer<expr> args;
expr norm_e = convert(*op, e, args);
expr rep = mk_ac_app(*op, args);
auto ac_prs = m_state.m_op_info.find(*op);
lean_assert(ac_prs);
expr pr = mk_ac_refl_proof(norm_e, rep, ac_prs->first, ac_prs->second);
lean_trace(name({"debug", "cc", "ac"}), scope_trace_env s(m_ctx.env(), m_ctx);
auto out = tout();
out << "new term:\n" << e << "\n===>\n" << pp_term(out.get_formatter(), rep) << "\n";);
m_todo.emplace_back(e, rep, pr);
process();
dbg_trace_state();
}
void theory_ac::add_R_occ(expr const & arg, expr const & lhs, bool in_lhs) {
entry new_entry = *m_state.m_entries.find(arg);
occurrences occs = new_entry.get_R_occs(in_lhs);
occs.insert(lhs);
new_entry.set_R_occs(occs, in_lhs);
m_state.m_entries.insert(arg, new_entry);
}
void theory_ac::add_R_occs(expr const & e, expr const & lhs, bool in_lhs) {
if (is_ac_app(e)) {
unsigned nargs = get_ac_app_num_args(e);
expr const * args = get_ac_app_args(e);
add_R_occ(args[0], e, lhs);
for (unsigned i = 1; i < nargs; i++) {
if (args[i] != args[i-1])
add_R_occ(args[i], lhs, in_lhs);
}
} else {
add_R_occ(e, lhs, in_lhs);
}
}
void theory_ac::add_R_occs(expr const & lhs, expr const & rhs) {
add_R_occs(lhs, lhs, true);
add_R_occs(rhs, lhs, false);
}
optional<expr_pair> theory_ac::simplify_step(expr const & e) {
if (is_ac_app(e)) {
expr const & op = get_ac_app_op(e);
unsigned nargs = get_ac_app_num_args(e);
expr const * args = get_ac_app_args(e);
for (unsigned i = 0; i < nargs; i++) {
if (i == 0 || args[i] != args[i-1]) {
occurrences const & occs = m_state.m_entries.find(args[i])->get_R_lhs_occs();
optional<expr> t = occs.find_if([&](expr const & t) {
return is_ac_subset(t, e);
});
if (t) {
/*
e is of the form t*r, and we have t -> s with proof pr
So, the new simplified new_e is
s*r
with proof ac_simp_pr(e, t, s, s*r, pr) : e = s*r
*/
buffer<expr> new_args;
ac_diff(e, *t, new_args);
expr r = new_args.empty() ? mk_Prop() /* dummy value */ : mk_ac_app(op, new_args);
expr_pair const & p = *m_state.m_R.find(*t);
expr const & s = p.first;
expr const & pr = p.second;
ac_append(s, new_args);
expr new_e = mk_ac_app(op, new_args);
auto ac_prs = m_state.m_op_info.find(op);
lean_assert(ac_prs);
expr new_pr = mk_ac_simp_proof(e, *t, s, r, new_e, pr, ac_prs->first, ac_prs->second);
return optional<expr_pair>(mk_pair(new_e, new_pr));
}
}
}
} else if (expr_pair const * p = m_state.m_R.find(e)) {
return optional<expr_pair>(*p);
}
return optional<expr_pair>();
}
optional<expr_pair> theory_ac::simplify(expr const & e) {
optional<expr_pair> p = simplify_step(e);
if (!p) return p;
expr curr = p->first;
expr pr = p->second;
while (optional<expr_pair> p = simplify_step(curr)) {
expr new_curr = p->first;
pr = mk_eq_trans(m_ctx, e, curr, new_curr, pr, p->second);
curr = new_curr;
}
return optional<expr_pair>(mk_pair(curr, pr));
}
void theory_ac::process() {
while (!m_todo.empty()) {
expr lhs, rhs, pr;
std::tie(lhs, rhs, pr) = m_todo.back();
m_todo.pop_back();
/* Simplify lhs/rhs */
if (optional<expr_pair> p = simplify(lhs)) {
pr = mk_eq_trans(m_ctx, p->first, lhs, rhs, mk_eq_symm(m_ctx, lhs, p->first, p->second), pr);
lhs = p->first;
}
if (optional<expr_pair> p = simplify(rhs)) {
pr = mk_eq_trans(m_ctx, lhs, rhs, p->first, pr, p->second);
rhs = p->first;
}
dbg_trace_eq("after simp:", lhs, rhs);
/* Check trivial */
if (lhs == rhs) {
lean_trace(name({"debug", "cc", "ac"}), tout() << "trivial\n";);
continue;
}
if (!is_ac_app(lhs) && !is_ac_app(rhs) && m_cc.get_root(lhs) != m_cc.get_root(rhs)) {
/* Propagate new equality to congruence closure module */
m_cc.push_new_eq(lhs, rhs, mark_cc_theory_proof(pr));
}
/* Orient */
if (ac_lt(lhs, rhs)) {
pr = mk_eq_symm(m_ctx, lhs, rhs, pr);
std::swap(lhs, rhs);
}
/* Simplify (forward/backward) R (aka collapse/compose) */
// TODO(Leo)
/* Superpose */
// TODO(Leo)
/* Update R */
m_state.m_R.insert(lhs, mk_pair(rhs, pr));
add_R_occs(lhs, rhs);
}
}
void theory_ac::add_eq(expr const & e1, expr const & e2) {
lean_trace(name({"cc", "ac"}), scope_trace_env s(m_ctx.env(), m_ctx);
tout() << "new eq: " << e1 << " = " << e2 << "\n";);
// TODO(Leo)
dbg_trace_eq("new eq:", e1, e2);
m_todo.emplace_back(e1, e2, mk_delayed_cc_eq_proof(e1, e2));
process();
dbg_trace_state();
}
void initialize_theory_ac() {
register_trace_class(name({"cc", "ac"}));
register_trace_class(name({"debug", "cc", "ac"}));
g_ac_app_name = new name("ac_app");
g_ac_app_opcode = new std::string("ACApp");
g_ac_app_macro = new macro_definition(new ac_app_macro_cell());
register_macro_deserializer(*g_ac_app_opcode,
[](deserializer &, unsigned num, expr const * args) {
return mk_ac_app_core(num, args);
});
g_ac_simp_name = new name("ac_simp");
g_ac_simp_opcode = new std::string("ACSimp");
g_ac_simp_macro = new macro_definition(new ac_simp_macro_cell());
register_macro_deserializer(*g_ac_simp_opcode,
[](deserializer &, unsigned num, expr const * args) {
if (num != 8) corrupted_stream_exception();
return mk_ac_simp_proof(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]);
});
g_ac_refl_name = new name("ac_refl");
g_ac_refl_opcode = new std::string("ACRefl");
g_ac_refl_macro = new macro_definition(new ac_refl_macro_cell());
register_macro_deserializer(*g_ac_refl_opcode,
[](deserializer &, unsigned num, expr const * args) {
if (num != 4) corrupted_stream_exception();
return mk_ac_refl_proof(args[0], args[1], args[2], args[3]);
});
}
void finalize_theory_ac() {
delete g_ac_app_name;
delete g_ac_app_opcode;
delete g_ac_app_macro;
delete g_ac_simp_name;
delete g_ac_simp_opcode;
delete g_ac_simp_macro;
delete g_ac_refl_name;
delete g_ac_refl_opcode;
delete g_ac_refl_macro;
}
}

View file

@ -13,28 +13,65 @@ class congruence_closure;
/* Associativity and commutativity by completion */
class theory_ac {
public:
struct occurrences {
class occurrences {
rb_expr_tree m_occs;
unsigned m_size;
unsigned m_size{0};
public:
void insert(expr const & e) {
if (m_occs.contains(e)) return;
m_occs.insert(e);
m_size++;
}
void erase(expr const & e) {
if (m_occs.contains(e)) {
m_occs.erase(e);
m_size--;
}
}
unsigned size() const { return m_size; }
template<typename F>
optional<expr> find_if(F && f) const {
return m_occs.find_if(f);
}
};
struct entry {
/* m_expr is the term in the congruence closure
module being represented by this entry */
unsigned m_idx;
occurrences m_R_occs[2];
entry() {}
entry(unsigned idx):m_idx(idx) {}
occurrences const & get_R_occs(bool lhs) const { return m_R_occs[lhs]; }
occurrences const & get_R_lhs_occs() const { return get_R_occs(true); }
void set_R_occs(occurrences const & occs, bool lhs) { m_R_occs[lhs] = occs; }
};
typedef std::tuple<expr, expr, expr> expr_triple;
struct state {
/* Mapping from operators occurring in terms and their canonical
representation in this module */
rb_expr_map<expr> m_can_ops;
rb_expr_map<expr> m_can_ops;
/* Mapping from canonical AC operators to AC proof terms. */
rb_expr_map<pair<expr, expr>> m_op_info;
rb_expr_map<expr_pair> m_op_info;
/* rewriting rules */
unsigned m_next_idx{0};
rb_expr_map<entry> m_entries;
/* Confluent rewriting system */
rb_expr_map<expr_pair> m_R;
rb_expr_map<occurrences> m_R_lhs_occs;
rb_expr_map<occurrences> m_R_rhs_occs;
/* Mapping from cc terms and their normal form in the AC theory. */
rb_expr_map<expr_pair> m_N;
rb_expr_map<expr> m_N_inv;
rb_expr_map<occurrences> m_N_rhs_occs;
format pp_term(formatter const & fmt, expr const & e) const;
format pp_decl(formatter const & fmt, expr const & e) const;
format pp_decls(formatter const & fmt) const;
format pp_R(formatter const & fmt) const;
format pp(formatter const & fmt) const;
};
private:
@ -42,15 +79,30 @@ private:
congruence_closure & m_cc;
state & m_state;
ac_manager m_ac_manager;
buffer<expr_triple> m_todo;
optional<expr> is_ac(expr const & e);
expr convert(expr const & op, expr const & e, buffer<expr> & args);
bool internalize_var(expr const & e);
void add_R_occ(expr const & arg, expr const & lhs, bool in_lhs);
void add_R_occs(expr const & e, expr const & lhs, bool in_lhs);
void add_R_occs(expr const & lhs, expr const & rhs);
optional<expr_pair> simplify_step(expr const & e);
optional<expr_pair> simplify(expr const & e);
void process();
void dbg_trace_state() const;
void dbg_trace_eq(char const * header, expr const & lhs, expr const & rhs) const;
public:
theory_ac(congruence_closure & cc, state & s);
~theory_ac();
optional<expr> internalize(expr const & e, optional<expr> const & parent);
void internalize(expr const & e, optional<expr> const & parent);
void add_eq(expr const & e1, expr const & e2);
format pp_term(formatter const & fmt, expr const & e) const {
return m_state.pp_term(fmt, e);
}
};
void initialize_theory_ac();

View file

@ -0,0 +1,101 @@
/*
Copyright (c) 2016 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include "library/annotation.h"
#include "library/util.h"
#include "library/replace_visitor.h"
#include "library/tactic/congruence/congruence_closure.h"
namespace lean {
static name * g_cc_proof_name = nullptr;
static macro_definition * g_cc_proof_macro = nullptr;
class cc_proof_macro_cell : public macro_definition_cell {
public:
virtual name get_name() const { return *g_cc_proof_name; }
virtual expr check_type(expr const & e, abstract_type_context & ctx, bool) const {
return mk_eq(ctx, macro_arg(e, 0), macro_arg(e, 1));
}
virtual optional<expr> expand(expr const &, abstract_type_context &) const {
/* This is a temporary/delayed proof step. */
lean_unreachable();
}
virtual void write(serializer &) const {
/* This is a temporary/delayed proof step. */
lean_unreachable();
}
virtual bool operator==(macro_definition_cell const & other) const {
cc_proof_macro_cell const * other_ptr = dynamic_cast<cc_proof_macro_cell const *>(&other);
return other_ptr;
}
virtual unsigned hash() const { return 23; }
};
/* Delayed (congruence closure procedure) proof.
This proof is a placeholder for the real one that is computed only if needed. */
expr mk_delayed_cc_eq_proof(expr const & e1, expr const & e2) {
expr args[2] = {e1, e2};
return mk_macro(*g_cc_proof_macro, 2, args);
}
bool is_delayed_cc_eq_proof(expr const & e) {
return is_macro(e) && dynamic_cast<cc_proof_macro_cell const *>(macro_def(e).raw());
}
static name * g_theory_proof = nullptr;
expr mark_cc_theory_proof(expr const & pr) {
return mk_annotation(*g_theory_proof, pr);
}
bool is_cc_theory_proof(expr const & e) {
return is_annotation(e, *g_theory_proof);
}
expr get_cc_theory_proof_arg(expr const & pr) {
lean_assert(is_cc_theory_proof(pr));
return get_annotation_arg(pr);
}
class expand_delayed_cc_proofs_fn : public replace_visitor {
congruence_closure const & m_cc;
expr visit_macro(expr const & e) {
if (is_delayed_cc_eq_proof(e)) {
expr const & lhs = macro_arg(e, 0);
expr const & rhs = macro_arg(e, 1);
return *m_cc.get_eq_proof(lhs, rhs);
} else {
return replace_visitor::visit_macro(e);
}
}
public:
expand_delayed_cc_proofs_fn(congruence_closure const & cc):m_cc(cc) {}
};
expr expand_delayed_cc_proofs(congruence_closure const & cc, expr const & e) {
return expand_delayed_cc_proofs_fn(cc)(e);
}
void initialize_congruence_util() {
g_cc_proof_name = new name("cc_proof");
g_cc_proof_macro = new macro_definition(new cc_proof_macro_cell());
g_theory_proof = new name("th_proof");
register_annotation(*g_theory_proof);
}
void finalize_congruence_util() {
delete g_cc_proof_macro;
delete g_cc_proof_name;
delete g_theory_proof;
}
}

View file

@ -0,0 +1,21 @@
/*
Copyright (c) 2016 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include "kernel/expr.h"
namespace lean {
expr mk_delayed_cc_eq_proof(expr const & e1, expr const & e2);
expr mark_cc_theory_proof(expr const & pr);
expr get_cc_theory_proof_arg(expr const & pr);
bool is_cc_theory_proof(expr const & e);
class congruence_closure;
expr expand_delayed_cc_proofs(congruence_closure const & cc, expr const & e);
void initialize_congruence_util();
void finalize_congruence_util();
}

View file

@ -584,6 +584,8 @@ expr mk_eq_refl(abstract_type_context & ctx, expr const & a) {
}
expr mk_eq_symm(abstract_type_context & ctx, expr const & H) {
if (is_app_of(H, get_eq_refl_name()))
return H;
expr p = ctx.whnf(ctx.infer(H));
lean_assert(is_eq(p));
expr lhs = app_arg(app_fn(p));
@ -594,6 +596,10 @@ expr mk_eq_symm(abstract_type_context & ctx, expr const & H) {
}
expr mk_eq_trans(abstract_type_context & ctx, expr const & H1, expr const & H2) {
if (is_app_of(H1, get_eq_refl_name()))
return H2;
if (is_app_of(H2, get_eq_refl_name()))
return H1;
expr p1 = ctx.whnf(ctx.infer(H1));
expr p2 = ctx.whnf(ctx.infer(H2));
lean_assert(is_eq(p1) && is_eq(p2));