From 9065cf0350fb4eaf01f1aa8205d1f7192d72a59d Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 28 Dec 2016 21:24:46 -0800 Subject: [PATCH] feat(library/tactic/congruence/theory_ac): add internalization, interface with congruence closure module, and trivial/simp/orient transitions Still missing: superpose, collapse and compose transitions. --- src/kernel/expr.h | 4 + src/library/app_builder.cpp | 28 + src/library/app_builder.h | 2 + src/library/expr_lt.h | 1 + src/library/tactic/congruence/CMakeLists.txt | 2 +- .../tactic/congruence/congruence_closure.cpp | 59 +- .../tactic/congruence/congruence_closure.h | 15 +- src/library/tactic/congruence/init_module.cpp | 3 + src/library/tactic/congruence/theory_ac.cpp | 609 +++++++++++++++++- src/library/tactic/congruence/theory_ac.h | 76 ++- src/library/tactic/congruence/util.cpp | 101 +++ src/library/tactic/congruence/util.h | 21 + src/library/util.cpp | 6 + 13 files changed, 872 insertions(+), 55 deletions(-) create mode 100644 src/library/tactic/congruence/util.cpp create mode 100644 src/library/tactic/congruence/util.h diff --git a/src/kernel/expr.h b/src/kernel/expr.h index ba6f424420..61a79b5d04 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -383,6 +383,10 @@ public: unsigned hash() const { return m_ptr->hash(); } void write(serializer & s) const { return m_ptr->write(s); } macro_definition_cell const * raw() const { return m_ptr; } + + friend bool is_eqp(macro_definition const & d1, macro_definition const & d2) { + return d1.m_ptr == d2.m_ptr; + } }; /** \brief Macro attachments */ diff --git a/src/library/app_builder.cpp b/src/library/app_builder.cpp index 642f1d5698..a01c8208fa 100644 --- a/src/library/app_builder.cpp +++ b/src/library/app_builder.cpp @@ -499,6 +499,8 @@ public: } } expr mk_eq_symm(expr const & H) { + if (is_app_of(H, get_eq_refl_name())) + return H; expr p = m_ctx.relaxed_whnf(m_ctx.infer(H)); expr A, lhs, rhs; if (!is_eq(p, A, lhs, rhs)) { @@ -508,6 +510,11 @@ public: level lvl = get_level(A); return ::lean::mk_app(mk_constant(get_eq_symm_name(), {lvl}), A, lhs, rhs, H); } + expr mk_eq_symm(expr const & a, expr const & b, expr const & H) { + expr A = m_ctx.infer(a); + level lvl = get_level(A); + return ::lean::mk_app(mk_constant(get_eq_symm_name(), {lvl}), A, a, b, H); + } expr mk_iff_symm(expr const & H) { expr p = m_ctx.infer(H); expr lhs, rhs; @@ -547,6 +554,10 @@ public: } } expr mk_eq_trans(expr const & H1, expr const & H2) { + if (is_app_of(H1, get_eq_refl_name())) + return H2; + if (is_app_of(H2, get_eq_refl_name())) + return H1; expr p1 = m_ctx.relaxed_whnf(m_ctx.infer(H1)); expr p2 = m_ctx.relaxed_whnf(m_ctx.infer(H2)); expr A, lhs1, rhs1, lhs2, rhs2; @@ -559,6 +570,15 @@ public: level lvl = get_level(A); return ::lean::mk_app({mk_constant(get_eq_trans_name(), {lvl}), A, lhs1, rhs1, rhs2, H1, H2}); } + expr mk_eq_trans(expr const & a, expr const & b, expr const & c, expr const & H1, expr const & H2) { + if (is_app_of(H1, get_eq_refl_name())) + return H2; + if (is_app_of(H2, get_eq_refl_name())) + return H1; + expr A = m_ctx.infer(a); + level lvl = get_level(A); + return ::lean::mk_app({mk_constant(get_eq_trans_name(), {lvl}), A, a, b, c, H1, H2}); + } expr mk_iff_trans(expr const & H1, expr const & H2) { expr p1 = m_ctx.infer(H1); expr p2 = m_ctx.infer(H2); @@ -820,6 +840,10 @@ expr mk_eq_symm(type_context & ctx, expr const & H) { return app_builder(ctx).mk_eq_symm(H); } +expr mk_eq_symm(type_context & ctx, expr const & a, expr const & b, expr const & H) { + return app_builder(ctx).mk_eq_symm(a, b, H); +} + expr mk_iff_symm(type_context & ctx, expr const & H) { return app_builder(ctx).mk_iff_symm(H); } @@ -836,6 +860,10 @@ expr mk_eq_trans(type_context & ctx, expr const & H1, expr const & H2) { return app_builder(ctx).mk_eq_trans(H1, H2); } +expr mk_eq_trans(type_context & ctx, expr const & a, expr const & b, expr const & c, expr const & H1, expr const & H2) { + return app_builder(ctx).mk_eq_trans(a, b, c, H1, H2); +} + expr mk_iff_trans(type_context & ctx, expr const & H1, expr const & H2) { return app_builder(ctx).mk_iff_trans(H1, H2); } diff --git a/src/library/app_builder.h b/src/library/app_builder.h index 7b80fda066..114815b8b1 100644 --- a/src/library/app_builder.h +++ b/src/library/app_builder.h @@ -91,12 +91,14 @@ expr mk_heq_refl(type_context & ctx, expr const & a); /** \brief Similar a symmetry proof for the given relation */ expr mk_symm(type_context & ctx, name const & relname, expr const & H); expr mk_eq_symm(type_context & ctx, expr const & H); +expr mk_eq_symm(type_context & ctx, expr const & a, expr const & b, expr const & H); expr mk_iff_symm(type_context & ctx, expr const & H); expr mk_heq_symm(type_context & ctx, expr const & H); /** \brief Similar a transitivity proof for the given relation */ expr mk_trans(type_context & ctx, name const & relname, expr const & H1, expr const & H2); expr mk_eq_trans(type_context & ctx, expr const & H1, expr const & H2); +expr mk_eq_trans(type_context & ctx, expr const & a, expr const & b, expr const & c, expr const & H1, expr const & H2); expr mk_iff_trans(type_context & ctx, expr const & H1, expr const & H2); expr mk_heq_trans(type_context & ctx, expr const & H1, expr const & H2); diff --git a/src/library/expr_lt.h b/src/library/expr_lt.h index ebfb121398..b0702fd420 100644 --- a/src/library/expr_lt.h +++ b/src/library/expr_lt.h @@ -18,6 +18,7 @@ namespace lean { bool is_lt(expr const & a, expr const & b, bool use_hash); /** \brief Similar to is_lt, but universe level parameter names are ignored. */ bool is_lt_no_level_params(expr const & a, expr const & b); +inline bool is_hash_lt(expr const & a, expr const & b) { return is_lt(a, b, true); } inline bool operator<(expr const & a, expr const & b) { return is_lt(a, b, true); } inline bool operator>(expr const & a, expr const & b) { return is_lt(b, a, true); } inline bool operator<=(expr const & a, expr const & b) { return !is_lt(b, a, true); } diff --git a/src/library/tactic/congruence/CMakeLists.txt b/src/library/tactic/congruence/CMakeLists.txt index 86a9cb3324..04d3b94ef3 100644 --- a/src/library/tactic/congruence/CMakeLists.txt +++ b/src/library/tactic/congruence/CMakeLists.txt @@ -1,2 +1,2 @@ add_library(congruence OBJECT congruence_closure.cpp congruence_tactics.cpp - hinst_lemmas.cpp ematch.cpp theory_ac.cpp init_module.cpp) + hinst_lemmas.cpp ematch.cpp theory_ac.cpp util.cpp init_module.cpp) diff --git a/src/library/tactic/congruence/congruence_closure.cpp b/src/library/tactic/congruence/congruence_closure.cpp index 60877ddcb7..d7b9403de1 100644 --- a/src/library/tactic/congruence/congruence_closure.cpp +++ b/src/library/tactic/congruence/congruence_closure.cpp @@ -17,6 +17,7 @@ Author: Leonardo de Moura #include "library/app_builder.h" #include "library/projection.h" #include "library/constructions/constructor.h" +#include "library/tactic/congruence/util.h" #include "library/tactic/congruence/congruence_closure.h" namespace lean { @@ -653,15 +654,15 @@ void congruence_closure::propagate_inst_implicit(expr const & e) { } } -void congruence_closure::set_ac_rep(expr const & e, expr const & ac_rep) { +void congruence_closure::set_ac_var(expr const & e) { expr e_root = get_root(e); - auto n = get_entry(e); - if (n->m_ac_rep) { - m_ac.add_eq(*n->m_ac_rep, ac_rep); + auto root_entry = get_entry(e_root); + if (root_entry->m_ac_var) { + m_ac.add_eq(*root_entry->m_ac_var, e); } else { - entry new_entry = *n; - new_entry.m_ac_rep = some_expr(ac_rep); - m_state.m_entries.insert(e_root, new_entry); + entry new_root_entry = *root_entry; + new_root_entry.m_ac_var = some_expr(e); + m_state.m_entries.insert(e_root, new_root_entry); } } @@ -807,9 +808,8 @@ void congruence_closure::internalize_core(expr const & e, bool toplevel, bool to break; }} - if (optional ac_rep = m_ac.internalize(e, parent)) { - set_ac_rep(e, *ac_rep); - } + if (m_state.m_config.m_ac) + m_ac.internalize(e, parent); } /* @@ -897,19 +897,26 @@ bool congruence_closure::has_heq_proofs(expr const & root) const { return get_entry(root)->m_heq_proofs; } +expr congruence_closure::flip_proof_core(expr const & H, bool flipped, bool heq_proofs) const { + expr new_H = H; + if (heq_proofs && is_eq(m_ctx.relaxed_whnf(m_ctx.infer(new_H)))) { + new_H = mk_heq_of_eq(m_ctx, new_H); + } + if (!flipped) { + return new_H; + } else { + return heq_proofs ? mk_heq_symm(m_ctx, new_H) : mk_eq_symm(m_ctx, new_H); + } +} + expr congruence_closure::flip_proof(expr const & H, bool flipped, bool heq_proofs) const { if (H == *g_congr_mark || H == *g_eq_true_mark || H == *g_refl_mark) { return H; + } else if (is_cc_theory_proof(H)) { + expr H1 = flip_proof_core(get_cc_theory_proof_arg(H), flipped, heq_proofs); + return mark_cc_theory_proof(H1); } else { - expr new_H = H; - if (heq_proofs && is_eq(m_ctx.relaxed_whnf(m_ctx.infer(new_H)))) { - new_H = mk_heq_of_eq(m_ctx, new_H); - } - if (!flipped) { - return new_H; - } else { - return heq_proofs ? mk_heq_symm(m_ctx, new_H) : mk_eq_symm(m_ctx, new_H); - } + return flip_proof_core(H, flipped, heq_proofs); } } @@ -1073,6 +1080,8 @@ expr congruence_closure::mk_proof(expr const & lhs, expr const & rhs, expr const expr type = heq_proofs ? mk_heq(m_ctx, lhs, rhs) : mk_eq(m_ctx, lhs, rhs); expr proof = heq_proofs ? mk_heq_refl(m_ctx, lhs) : mk_eq_refl(m_ctx, lhs); return mk_app(mk_constant(get_id_locked_name(), {mk_level_zero()}), type, proof); + } else if (is_cc_theory_proof(H)) { + return expand_delayed_cc_proofs(*this, get_cc_theory_proof_arg(H)); } else { return H; } @@ -1376,10 +1385,10 @@ void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, new_r1.m_next = r2->m_next; new_r2.m_next = r1->m_next; new_r2.m_size += r1->m_size; - optional ac_rep1 = r1->m_ac_rep; - optional ac_rep2 = r2->m_ac_rep; - if (!ac_rep2) - new_r2.m_ac_rep = ac_rep1; + optional ac_var1 = r1->m_ac_var; + optional ac_var2 = r2->m_ac_var; + if (!ac_var2) + new_r2.m_ac_var = ac_var1; if (heq_proof) new_r2.m_heq_proofs = true; m_state.m_entries.insert(e1_root, new_r1); @@ -1404,8 +1413,8 @@ void congruence_closure::add_eqv_step(expr e1, expr e2, expr const & H, m_state.m_parents.insert(e2_root, ps2); } - if (ac_rep1 && ac_rep2) - m_ac.add_eq(*ac_rep1, *ac_rep2); + if (ac_var1 && ac_var2) + m_ac.add_eq(*ac_var1, *ac_var2); // propagate new hypotheses back to current state if (!to_propagate.empty()) { diff --git a/src/library/tactic/congruence/congruence_closure.h b/src/library/tactic/congruence/congruence_closure.h index 9c3651c85d..e15a11e93b 100644 --- a/src/library/tactic/congruence/congruence_closure.h +++ b/src/library/tactic/congruence/congruence_closure.h @@ -53,8 +53,8 @@ class congruence_closure { store 'target' at 'm_target', and 'H' at 'm_proof'. Both fields are none if 'e' == m_root */ optional m_target; optional m_proof; - /* Representative used in the AC theory */ - optional m_ac_rep; + /* Variable in the AC theory. */ + optional m_ac_var; unsigned m_flipped:1; // proof has been flipped unsigned m_to_propagate:1; // must be propagated back to state when in equivalence class containing true/false unsigned m_interpreted:1; // true if the node should be viewed as an abstract value @@ -102,7 +102,8 @@ public: unsigned m_ignore_instances:1; unsigned m_values:1; unsigned m_all_ho:1; - config() { m_ignore_instances = true; m_values = true; m_all_ho = false; } + unsigned m_ac:1; + config() { m_ignore_instances = true; m_values = true; m_all_ho = false; m_ac = true; } }; class state { @@ -165,6 +166,8 @@ private: refl_info_getter m_refl_info_getter; theory_ac m_ac; + friend class theory_ac; + int compare_symm(expr lhs1, expr rhs1, expr lhs2, expr rhs2) const; unsigned symm_hash(expr const & lhs, expr const & rhs) const; optional is_binary_relation(expr const & e, expr & lhs, expr & rhs) const; @@ -185,10 +188,11 @@ private: void add_symm_congruence_table(expr const & e); void mk_entry_core(expr const & e, bool to_propagate, bool interpreted = false); void mk_entry(expr const & e, bool to_propagate); - void set_ac_rep(expr const & e, expr const & ac_rep); + void set_ac_var(expr const & e); void internalize_app(expr const & e, bool toplevel, bool to_propagate); void internalize_core(expr const & e, bool toplevel, bool to_propagate, optional const & parent); void push_todo(expr const & lhs, expr const & rhs, expr const & H, bool heq_proof); + void push_new_eq(expr const & lhs, expr const & rhs, expr const & H) { push_todo(lhs, rhs, H, false); } void push_refl_eq(expr const & lhs, expr const & rhs); void invert_trans(expr const & e, bool new_flipped, optional new_target, optional new_proof); void invert_trans(expr const & e); @@ -196,6 +200,7 @@ private: void reinsert_parents(expr const & e); void update_mt(expr const & e); bool has_heq_proofs(expr const & root) const; + expr flip_proof_core(expr const & H, bool flipped, bool heq_proofs) const; expr flip_proof(expr const & H, bool flipped, bool heq_proofs) const; optional mk_ext_hcongr_lemma(expr const & fn, unsigned nargs) const; expr mk_trans(expr const & H1, expr const & H2, bool heq_proofs) const; @@ -219,7 +224,6 @@ private: bool check_eqc(expr const & e) const; friend ext_congr_lemma_cache_ptr const & get_cache_ptr(congruence_closure const & cc); - public: congruence_closure(type_context & ctx, state & s); ~congruence_closure(); @@ -227,6 +231,7 @@ public: environment const & env() const { return m_ctx.env(); } type_context & ctx() { return m_ctx; } transparency_mode mode() const { return m_mode; } + defeq_canonizer & get_defeq_canonizer() { return m_defeq_canonizer; } /** \brief Register expression \c e in this data-structure. It creates entries for each sub-expression in \c e. diff --git a/src/library/tactic/congruence/init_module.cpp b/src/library/tactic/congruence/init_module.cpp index 0b64404f02..fa51212b42 100644 --- a/src/library/tactic/congruence/init_module.cpp +++ b/src/library/tactic/congruence/init_module.cpp @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ +#include "library/tactic/congruence/util.h" #include "library/tactic/congruence/congruence_closure.h" #include "library/tactic/congruence/congruence_tactics.h" #include "library/tactic/congruence/hinst_lemmas.h" @@ -12,6 +13,7 @@ Author: Leonardo de Moura namespace lean { void initialize_congruence_module() { + initialize_congruence_util(); initialize_congruence_closure(); initialize_congruence_tactics(); initialize_hinst_lemmas(); @@ -25,5 +27,6 @@ void finalize_congruence_module() { finalize_hinst_lemmas(); finalize_congruence_tactics(); finalize_congruence_closure(); + finalize_congruence_util(); } } diff --git a/src/library/tactic/congruence/theory_ac.cpp b/src/library/tactic/congruence/theory_ac.cpp index bebf10c720..0c70531c06 100644 --- a/src/library/tactic/congruence/theory_ac.cpp +++ b/src/library/tactic/congruence/theory_ac.cpp @@ -4,11 +4,385 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ +#include +#include #include "library/trace.h" +#include "library/app_builder.h" +#include "library/kernel_serializer.h" +#include "library/tactic/ac_tactics.h" +#include "library/tactic/congruence/util.h" #include "library/tactic/congruence/congruence_closure.h" #include "library/tactic/congruence/theory_ac.h" +/* TODO(Leo): reduce after testing */ +#define AC_TRUST_LEVEL 10000000 + namespace lean { +enum class ac_term_kind { App }; + +static name * g_ac_app_name = nullptr; +static macro_definition * g_ac_app_macro = nullptr; +static std::string * g_ac_app_opcode = nullptr; + +static expr expand_ac_core(expr const & e) { + unsigned nargs = macro_num_args(e); + unsigned i = nargs - 1; + expr const & op = macro_arg(e, i); + --i; + expr r = macro_arg(e, i); + while (i > 0) { + --i; + r = mk_app(op, macro_arg(e, i), r); + } + return r; +} + +class ac_app_macro_cell : public macro_definition_cell { +public: + virtual name get_name() const { return *g_ac_app_name; } + + virtual unsigned trust_level() const { return AC_TRUST_LEVEL; } + + virtual expr check_type(expr const & e, abstract_type_context & ctx, bool) const { + return ctx.infer(macro_arg(e, 0)); + } + + virtual optional expand(expr const & e, abstract_type_context &) const { + return some_expr(expand_ac_core(e)); + } + + virtual void write(serializer & s) const { + s.write_string(*g_ac_app_opcode); + } + + virtual bool operator==(macro_definition_cell const & other) const { + ac_app_macro_cell const * other_ptr = dynamic_cast(&other); + return other_ptr; + } + + virtual unsigned hash() const { return 37; } +}; + +static expr mk_ac_app_core(unsigned nargs, expr const * args_op) { + lean_assert(nargs >= 3); + return mk_macro(*g_ac_app_macro, nargs, args_op); +} + +static expr mk_ac_app_core(expr const & op, buffer & args) { + lean_assert(args.size() >= 2); + args.push_back(op); + expr r = mk_ac_app_core(args.size(), args.data()); + args.pop_back(); + return r; +} + +static expr mk_ac_app(expr const & op, buffer & args) { + lean_assert(args.size() > 0); + if (args.size() == 1) { + return args[0]; + } else { + std::sort(args.begin(), args.end(), is_hash_lt); + return mk_ac_app_core(op, args); + } +} + +static bool is_ac_app(expr const & e) { + return is_macro(e) && is_eqp(macro_def(e), *g_ac_app_macro); +} + +static expr const & get_ac_app_op(expr const & e) { + lean_assert(is_ac_app(e)); + return macro_arg(e, macro_num_args(e) - 1); +} + +static unsigned get_ac_app_num_args(expr const & e) { + lean_assert(is_ac_app(e)); + return macro_num_args(e) - 1; +} + +static expr const * get_ac_app_args(expr const & e) { + lean_assert(is_ac_app(e)); + return macro_args(e); +} + +/* Return true iff e1 is a "subset" of e2. + Example: The result is true for e1 := (a*a*a*b*d) and e2 := (a*a*a*a*b*b*c*d*d) */ +static bool is_ac_subset(expr const & e1, expr const & e2) { + if (is_ac_app(e1)) { + if (is_ac_app(e2) && get_ac_app_op(e1) == get_ac_app_op(e2)) { + unsigned nargs1 = get_ac_app_num_args(e1); + unsigned nargs2 = get_ac_app_num_args(e2); + if (nargs1 > nargs2) return false; + expr const * args1 = get_ac_app_args(e1); + expr const * args2 = get_ac_app_args(e2); + unsigned i1 = 0; + unsigned i2 = 0; + while (i1 < nargs1 && i2 < nargs2) { + if (args1[i1] == args2[i2]) { + i1++; + i2++; + } else if (is_hash_lt(args2[i2], args1[i1])) { + i2++; + } else { + lean_assert(is_hash_lt(args1[i1], args2[i2])); + return false; + } + } + return i1 == nargs1; + } else { + return false; + } + } else { + if (is_ac_app(e2)) { + unsigned nargs2 = get_ac_app_num_args(e2); + expr const * args2 = get_ac_app_args(e2); + return std::find(args2, args2+nargs2, e1) != args2+nargs2; + } else { + return e1 == e2; + } + } +} + +/* Store in r e1\e2. + Example: given e1 := (a*a*a*a*b*b*c*d*d*d) and e2 := (a*a*a*b*b*d), + the result is (a, c, d, d) + + \pre is_ac_subset(e2, e1) */ +static void ac_diff(expr const & e1, expr const & e2, buffer & r) { + lean_assert(is_ac_subset(e2, e1)); + if (is_ac_app(e1)) { + if (is_ac_app(e2) && get_ac_app_op(e1) == get_ac_app_op(e2)) { + unsigned nargs1 = get_ac_app_num_args(e1); + unsigned nargs2 = get_ac_app_num_args(e2); + lean_assert(nargs1 >= nargs2); + expr const * args1 = get_ac_app_args(e1); + expr const * args2 = get_ac_app_args(e2); + unsigned i2 = 0; + for (unsigned i1 = 0; i1 < nargs1; i1++) { + if (i2 == nargs2) { + r.push_back(args1[i1]); + } else if (args1[i1] == args2[i2]) { + i2++; + } else { + lean_assert(is_hash_lt(args1[i1], args2[i2])); + r.push_back(args1[i1]); + } + } + } else { + bool found = false; + unsigned nargs1 = get_ac_app_num_args(e1); + expr const * args1 = get_ac_app_args(e1); + for (unsigned i = 0; i < nargs1; i++) { + if (!found && args1[i] == e2) { + found = true; + } else { + r.push_back(args1[i]); + } + } + lean_assert(found); + } + } else { + lean_assert(!is_ac_app(e1)); + lean_assert(!is_ac_app(e2)); + lean_assert(e1 == e2); + } +} + +static void ac_append(expr const & e, buffer & r) { + if (is_ac_app(e)) { + r.append(get_ac_app_num_args(e), get_ac_app_args(e)); + } else { + r.push_back(e); + } +} + +/* lexdeg order */ +static bool ac_lt(expr const & e1, expr const & e2) { + if (is_ac_app(e1)) { + if (is_ac_app(e2) && get_ac_app_op(e1) == get_ac_app_op(e2)) { + unsigned nargs1 = get_ac_app_num_args(e1); + unsigned nargs2 = get_ac_app_num_args(e2); + if (nargs1 < nargs2) return true; + if (nargs1 > nargs2) return false; + expr const * args1 = get_ac_app_args(e1); + expr const * args2 = get_ac_app_args(e2); + for (unsigned i = 0; i < nargs1; i++) { + if (args1[i] != args2[i]) + return is_hash_lt(args1[i], args2[i]); + } + return false; + } else { + return false; + } + } else { + if (is_ac_app(e2)) { + return true; + } else { + return is_hash_lt(e1, e2); + } + } +} + +static expr expand_if_ac_app(expr const & e) { + if (is_ac_app(e)) + return expand_ac_core(e); + else + return e; +} + +static name * g_ac_simp_name = nullptr; +static macro_definition * g_ac_simp_macro = nullptr; +static std::string * g_ac_simp_opcode = nullptr; + +class ac_simp_macro_cell : public macro_definition_cell { +public: + virtual name get_name() const { return *g_ac_simp_name; } + + virtual expr check_type(expr const & e, abstract_type_context & ctx, bool) const { + return mk_eq(ctx, macro_arg(e, 0), macro_arg(e, 3)); + } + + virtual unsigned trust_level() const { return AC_TRUST_LEVEL; } + + virtual optional expand(expr const & H, abstract_type_context & ctx) const { + expr e = expand_if_ac_app(macro_arg(H, 0)); /* it is of the form t*r */ + expr t = expand_if_ac_app(macro_arg(H, 1)); + expr s = expand_if_ac_app(macro_arg(H, 2)); + expr r = expand_if_ac_app(macro_arg(H, 3)); + expr sr = expand_if_ac_app(macro_arg(H, 4)); + expr t_eq_s = expand_if_ac_app(macro_arg(H, 5)); + expr const & assoc = macro_arg(H, 6); + expr const & comm = macro_arg(H, 7); + if (e == sr) { + return some_expr(mk_eq_refl(ctx, e)); + } else if (e == t) { + lean_assert(s == sr); + return some_expr(t_eq_s); + } else { + expr op = app_fn(app_fn(e)); + expr op_r = mk_app(op, r); + expr rt = mk_app(op_r, t); + expr rs = mk_app(op, r, s); + expr rt_eq_rs = mk_congr_arg(ctx, op_r, t_eq_s); + expr e_eq_rt = perm_ac(ctx, op, assoc, comm, e, rt); + expr rs_eq_sr = perm_ac(ctx, op, assoc, comm, rs, sr); + return some_expr(mk_eq_trans(ctx, mk_eq_trans(ctx, e_eq_rt, rt_eq_rs), rs_eq_sr)); + } + } + + virtual void write(serializer & s) const { + s.write_string(*g_ac_simp_opcode); + } + + virtual bool operator==(macro_definition_cell const & other) const { + ac_simp_macro_cell const * other_ptr = dynamic_cast(&other); + return other_ptr; + } + + virtual unsigned hash() const { return 31; } +}; + +/* Given e of the form t*r, (pr : t = s) and s_r is of the form (s*r), + return a proof for e = s_r */ +static expr mk_ac_simp_proof(expr const & e, expr const & t, expr const & s, expr const & r, expr const & s_r, expr const & pr, expr const & assoc, expr const & comm) { + expr args[8] = {e, t, s, r, s_r, pr, assoc, comm}; + return mk_macro(*g_ac_simp_macro, 8, args); +} + +static name * g_ac_refl_name = nullptr; +static macro_definition * g_ac_refl_macro = nullptr; +static std::string * g_ac_refl_opcode = nullptr; + +class ac_refl_macro_cell : public macro_definition_cell { +public: + virtual name get_name() const { return *g_ac_refl_name; } + + virtual expr check_type(expr const & e, abstract_type_context & ctx, bool) const { + return mk_eq(ctx, macro_arg(e, 0), macro_arg(e, 2)); + } + + virtual unsigned trust_level() const { return AC_TRUST_LEVEL; } + + virtual optional expand(expr const & e, abstract_type_context & ctx) const { + expr const & t = macro_arg(e, 0); + expr ac_t = macro_arg(e, 1); + if (is_ac_app(ac_t)) + ac_t = expand_ac_core(ac_t); + expr const & op = app_fn(app_fn(ac_t)); + expr const & assoc = macro_arg(e, 2); + expr const & comm = macro_arg(e, 3); + return some_expr(perm_ac(ctx, op, assoc, comm, t, ac_t)); + } + + virtual void write(serializer & s) const { + s.write_string(*g_ac_refl_opcode); + } + + virtual bool operator==(macro_definition_cell const & other) const { + ac_refl_macro_cell const * other_ptr = dynamic_cast(&other); + return other_ptr; + } + + virtual unsigned hash() const { return 31; } +}; + +/* Given e and ac_term that is provably equal to e using AC, return a proof for e = ac_term */ +static expr mk_ac_refl_proof(expr const & e, expr const & ac_term, expr const & assoc, expr const & comm) { + expr args[4] = {e, ac_term, assoc, comm}; + return mk_macro(*g_ac_refl_macro, 4, args); +} + +static char const * ac_var_prefix = "x_"; + +format theory_ac::state::pp_term(formatter const & fmt, expr const & e) const { + if (auto it = m_entries.find(e)) { + return format(ac_var_prefix) + format(it->m_idx); + } else if (is_ac_app(e)) { + format r = fmt(get_ac_app_op(e)); + unsigned nargs = get_ac_app_num_args(e); + expr const * args = get_ac_app_args(e); + for (unsigned i = 0; i < nargs; i++) { + r += line() + pp_term(fmt, args[i]); + } + return group(bracket("[", r, "]")); + } else { + tout() << "pp_term: " << e << "\n"; + lean_unreachable(); + } +} + +format theory_ac::state::pp_decl(formatter const & fmt, expr const & e) const { + lean_assert(m_entries.contains(e)); + auto it = m_entries.find(e); + return group(format(ac_var_prefix) + format(it->m_idx) + line() + format(":=") + line() + fmt(e)); +} + +format theory_ac::state::pp_decls(formatter const & fmt) const { + format r; + bool first = true; + m_entries.for_each([&](expr const & k, entry const &) { + if (first) first = false; else r += comma() + line(); + r += pp_decl(fmt, k); + }); + return group(bracket("{", r, "}")); +} + +format theory_ac::state::pp_R(formatter const & fmt) const { + format r; + unsigned indent = get_pp_indent(fmt.get_options()); + bool first = true; + m_R.for_each([&](expr const & k, expr_pair const & p) { + format curr = pp_term(fmt, k) + line() + format("-->") + nest(indent, line() + pp_term(fmt, p.first)); + if (first) first = false; else r += comma() + line(); + r += group(curr); + }); + return group(bracket("{", r, "}")); +} + +format theory_ac::state::pp(formatter const & fmt) const { + return group(bracket("[", pp_decls(fmt) + comma() + line() + pp_R(fmt), "]")); +} + theory_ac::theory_ac(congruence_closure & cc, state & s): m_ctx(cc.ctx()), m_cc(cc), @@ -24,7 +398,8 @@ optional theory_ac::is_ac(expr const & e) { if (!assoc_pr) return none_expr(); optional comm_pr = m_ac_manager.is_comm(e); if (!comm_pr) return none_expr(); - expr const & op = app_fn(app_fn(e)); + expr op = app_fn(app_fn(e)); + op = m_cc.get_defeq_canonizer().canonize(op); if (auto it = m_state.m_can_ops.find(op)) return some_expr(*it); optional found_op; @@ -42,31 +417,241 @@ optional theory_ac::is_ac(expr const & e) { } } -optional theory_ac::internalize(expr const & e, optional const & parent) { +expr theory_ac::convert(expr const & op, expr const & e, buffer & args) { + if (optional curr_op = is_ac(e)) { + if (op == *curr_op) { + expr arg1 = convert(op, app_arg(app_fn(e)), args); + expr arg2 = convert(op, app_arg(e), args); + return mk_app(op, arg1, arg2); + } + } + + internalize_var(e); + args.push_back(e); + return e; +} + +bool theory_ac::internalize_var(expr const & e) { + if (m_state.m_entries.contains(e)) return false; + m_state.m_entries.insert(e, entry(m_state.m_next_idx)); + m_state.m_next_idx++; + m_cc.set_ac_var(e); + return true; +} + +void theory_ac::dbg_trace_state() const { + lean_trace(name({"debug", "cc", "ac"}), scope_trace_env s(m_ctx.env(), m_ctx); + auto out = tout(); + out << m_state.pp(out.get_formatter()) << "\n";); +} + +void theory_ac::dbg_trace_eq(char const * header, expr const & lhs, expr const & rhs) const { + lean_trace(name({"debug", "cc", "ac"}), scope_trace_env s(m_ctx.env(), m_ctx); + auto out = tout(); + auto fmt = out.get_formatter(); + out << header << " " << pp_term(fmt, lhs) << " = " << pp_term(fmt, rhs) << "\n";); +} + +void theory_ac::internalize(expr const & e, optional const & parent) { auto op = is_ac(e); - if (!op) return none_expr(); + if (!op) return; optional parent_op; if (parent) parent_op = is_ac(*parent); - if (parent_op && *op == *parent_op) return none_expr(); + if (parent_op && *op == *parent_op) return; - // TODO(Leo): compute representative and initialize - expr rep = e; + if (!internalize_var(e)) return; - lean_trace(name({"cc", "ac"}), scope_trace_env s(m_ctx.env(), m_ctx); - tout() << "new term: " << rep << "\n";); - return some_expr(rep); + buffer args; + expr norm_e = convert(*op, e, args); + expr rep = mk_ac_app(*op, args); + auto ac_prs = m_state.m_op_info.find(*op); + lean_assert(ac_prs); + expr pr = mk_ac_refl_proof(norm_e, rep, ac_prs->first, ac_prs->second); + + lean_trace(name({"debug", "cc", "ac"}), scope_trace_env s(m_ctx.env(), m_ctx); + auto out = tout(); + out << "new term:\n" << e << "\n===>\n" << pp_term(out.get_formatter(), rep) << "\n";); + + m_todo.emplace_back(e, rep, pr); + process(); + dbg_trace_state(); +} + +void theory_ac::add_R_occ(expr const & arg, expr const & lhs, bool in_lhs) { + entry new_entry = *m_state.m_entries.find(arg); + occurrences occs = new_entry.get_R_occs(in_lhs); + occs.insert(lhs); + new_entry.set_R_occs(occs, in_lhs); + m_state.m_entries.insert(arg, new_entry); +} + +void theory_ac::add_R_occs(expr const & e, expr const & lhs, bool in_lhs) { + if (is_ac_app(e)) { + unsigned nargs = get_ac_app_num_args(e); + expr const * args = get_ac_app_args(e); + add_R_occ(args[0], e, lhs); + for (unsigned i = 1; i < nargs; i++) { + if (args[i] != args[i-1]) + add_R_occ(args[i], lhs, in_lhs); + } + } else { + add_R_occ(e, lhs, in_lhs); + } +} + +void theory_ac::add_R_occs(expr const & lhs, expr const & rhs) { + add_R_occs(lhs, lhs, true); + add_R_occs(rhs, lhs, false); +} + +optional theory_ac::simplify_step(expr const & e) { + if (is_ac_app(e)) { + expr const & op = get_ac_app_op(e); + unsigned nargs = get_ac_app_num_args(e); + expr const * args = get_ac_app_args(e); + for (unsigned i = 0; i < nargs; i++) { + if (i == 0 || args[i] != args[i-1]) { + occurrences const & occs = m_state.m_entries.find(args[i])->get_R_lhs_occs(); + optional t = occs.find_if([&](expr const & t) { + return is_ac_subset(t, e); + }); + if (t) { + /* + e is of the form t*r, and we have t -> s with proof pr + So, the new simplified new_e is + s*r + with proof ac_simp_pr(e, t, s, s*r, pr) : e = s*r + */ + buffer new_args; + ac_diff(e, *t, new_args); + expr r = new_args.empty() ? mk_Prop() /* dummy value */ : mk_ac_app(op, new_args); + expr_pair const & p = *m_state.m_R.find(*t); + expr const & s = p.first; + expr const & pr = p.second; + ac_append(s, new_args); + expr new_e = mk_ac_app(op, new_args); + auto ac_prs = m_state.m_op_info.find(op); + lean_assert(ac_prs); + expr new_pr = mk_ac_simp_proof(e, *t, s, r, new_e, pr, ac_prs->first, ac_prs->second); + return optional(mk_pair(new_e, new_pr)); + } + } + } + } else if (expr_pair const * p = m_state.m_R.find(e)) { + return optional(*p); + } + return optional(); +} + +optional theory_ac::simplify(expr const & e) { + optional p = simplify_step(e); + if (!p) return p; + expr curr = p->first; + expr pr = p->second; + while (optional p = simplify_step(curr)) { + expr new_curr = p->first; + pr = mk_eq_trans(m_ctx, e, curr, new_curr, pr, p->second); + curr = new_curr; + } + return optional(mk_pair(curr, pr)); +} + +void theory_ac::process() { + while (!m_todo.empty()) { + expr lhs, rhs, pr; + std::tie(lhs, rhs, pr) = m_todo.back(); + m_todo.pop_back(); + + /* Simplify lhs/rhs */ + if (optional p = simplify(lhs)) { + pr = mk_eq_trans(m_ctx, p->first, lhs, rhs, mk_eq_symm(m_ctx, lhs, p->first, p->second), pr); + lhs = p->first; + } + if (optional p = simplify(rhs)) { + pr = mk_eq_trans(m_ctx, lhs, rhs, p->first, pr, p->second); + rhs = p->first; + } + + dbg_trace_eq("after simp:", lhs, rhs); + + /* Check trivial */ + if (lhs == rhs) { + lean_trace(name({"debug", "cc", "ac"}), tout() << "trivial\n";); + continue; + } + + if (!is_ac_app(lhs) && !is_ac_app(rhs) && m_cc.get_root(lhs) != m_cc.get_root(rhs)) { + /* Propagate new equality to congruence closure module */ + m_cc.push_new_eq(lhs, rhs, mark_cc_theory_proof(pr)); + } + + /* Orient */ + if (ac_lt(lhs, rhs)) { + pr = mk_eq_symm(m_ctx, lhs, rhs, pr); + std::swap(lhs, rhs); + } + + /* Simplify (forward/backward) R (aka collapse/compose) */ + // TODO(Leo) + + /* Superpose */ + // TODO(Leo) + + /* Update R */ + m_state.m_R.insert(lhs, mk_pair(rhs, pr)); + add_R_occs(lhs, rhs); + } } void theory_ac::add_eq(expr const & e1, expr const & e2) { - lean_trace(name({"cc", "ac"}), scope_trace_env s(m_ctx.env(), m_ctx); - tout() << "new eq: " << e1 << " = " << e2 << "\n";); - // TODO(Leo) + dbg_trace_eq("new eq:", e1, e2); + m_todo.emplace_back(e1, e2, mk_delayed_cc_eq_proof(e1, e2)); + process(); + dbg_trace_state(); } void initialize_theory_ac() { register_trace_class(name({"cc", "ac"})); + register_trace_class(name({"debug", "cc", "ac"})); + + g_ac_app_name = new name("ac_app"); + g_ac_app_opcode = new std::string("ACApp"); + g_ac_app_macro = new macro_definition(new ac_app_macro_cell()); + register_macro_deserializer(*g_ac_app_opcode, + [](deserializer &, unsigned num, expr const * args) { + return mk_ac_app_core(num, args); + }); + + g_ac_simp_name = new name("ac_simp"); + g_ac_simp_opcode = new std::string("ACSimp"); + g_ac_simp_macro = new macro_definition(new ac_simp_macro_cell()); + register_macro_deserializer(*g_ac_simp_opcode, + [](deserializer &, unsigned num, expr const * args) { + if (num != 8) corrupted_stream_exception(); + return mk_ac_simp_proof(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]); + }); + + g_ac_refl_name = new name("ac_refl"); + g_ac_refl_opcode = new std::string("ACRefl"); + g_ac_refl_macro = new macro_definition(new ac_refl_macro_cell()); + register_macro_deserializer(*g_ac_refl_opcode, + [](deserializer &, unsigned num, expr const * args) { + if (num != 4) corrupted_stream_exception(); + return mk_ac_refl_proof(args[0], args[1], args[2], args[3]); + }); } void finalize_theory_ac() { + delete g_ac_app_name; + delete g_ac_app_opcode; + delete g_ac_app_macro; + + delete g_ac_simp_name; + delete g_ac_simp_opcode; + delete g_ac_simp_macro; + + delete g_ac_refl_name; + delete g_ac_refl_opcode; + delete g_ac_refl_macro; } } diff --git a/src/library/tactic/congruence/theory_ac.h b/src/library/tactic/congruence/theory_ac.h index 392f747ae7..efbd294e98 100644 --- a/src/library/tactic/congruence/theory_ac.h +++ b/src/library/tactic/congruence/theory_ac.h @@ -13,28 +13,65 @@ class congruence_closure; /* Associativity and commutativity by completion */ class theory_ac { public: - struct occurrences { + class occurrences { rb_expr_tree m_occs; - unsigned m_size; + unsigned m_size{0}; + public: + void insert(expr const & e) { + if (m_occs.contains(e)) return; + m_occs.insert(e); + m_size++; + } + + void erase(expr const & e) { + if (m_occs.contains(e)) { + m_occs.erase(e); + m_size--; + } + } + + unsigned size() const { return m_size; } + + template + optional find_if(F && f) const { + return m_occs.find_if(f); + } }; + struct entry { + /* m_expr is the term in the congruence closure + module being represented by this entry */ + unsigned m_idx; + occurrences m_R_occs[2]; + entry() {} + entry(unsigned idx):m_idx(idx) {} + occurrences const & get_R_occs(bool lhs) const { return m_R_occs[lhs]; } + occurrences const & get_R_lhs_occs() const { return get_R_occs(true); } + void set_R_occs(occurrences const & occs, bool lhs) { m_R_occs[lhs] = occs; } + }; + + typedef std::tuple expr_triple; + struct state { /* Mapping from operators occurring in terms and their canonical representation in this module */ - rb_expr_map m_can_ops; + rb_expr_map m_can_ops; /* Mapping from canonical AC operators to AC proof terms. */ - rb_expr_map> m_op_info; + rb_expr_map m_op_info; - /* rewriting rules */ + unsigned m_next_idx{0}; + + rb_expr_map m_entries; + + /* Confluent rewriting system */ rb_expr_map m_R; - rb_expr_map m_R_lhs_occs; - rb_expr_map m_R_rhs_occs; - /* Mapping from cc terms and their normal form in the AC theory. */ - rb_expr_map m_N; - rb_expr_map m_N_inv; - rb_expr_map m_N_rhs_occs; + format pp_term(formatter const & fmt, expr const & e) const; + format pp_decl(formatter const & fmt, expr const & e) const; + format pp_decls(formatter const & fmt) const; + format pp_R(formatter const & fmt) const; + format pp(formatter const & fmt) const; }; private: @@ -42,15 +79,30 @@ private: congruence_closure & m_cc; state & m_state; ac_manager m_ac_manager; + buffer m_todo; optional is_ac(expr const & e); + expr convert(expr const & op, expr const & e, buffer & args); + bool internalize_var(expr const & e); + void add_R_occ(expr const & arg, expr const & lhs, bool in_lhs); + void add_R_occs(expr const & e, expr const & lhs, bool in_lhs); + void add_R_occs(expr const & lhs, expr const & rhs); + optional simplify_step(expr const & e); + optional simplify(expr const & e); + void process(); + void dbg_trace_state() const; + void dbg_trace_eq(char const * header, expr const & lhs, expr const & rhs) const; public: theory_ac(congruence_closure & cc, state & s); ~theory_ac(); - optional internalize(expr const & e, optional const & parent); + void internalize(expr const & e, optional const & parent); void add_eq(expr const & e1, expr const & e2); + + format pp_term(formatter const & fmt, expr const & e) const { + return m_state.pp_term(fmt, e); + } }; void initialize_theory_ac(); diff --git a/src/library/tactic/congruence/util.cpp b/src/library/tactic/congruence/util.cpp new file mode 100644 index 0000000000..fbc5d97f59 --- /dev/null +++ b/src/library/tactic/congruence/util.cpp @@ -0,0 +1,101 @@ +/* +Copyright (c) 2016 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include "library/annotation.h" +#include "library/util.h" +#include "library/replace_visitor.h" +#include "library/tactic/congruence/congruence_closure.h" + +namespace lean { +static name * g_cc_proof_name = nullptr; +static macro_definition * g_cc_proof_macro = nullptr; + +class cc_proof_macro_cell : public macro_definition_cell { +public: + virtual name get_name() const { return *g_cc_proof_name; } + + virtual expr check_type(expr const & e, abstract_type_context & ctx, bool) const { + return mk_eq(ctx, macro_arg(e, 0), macro_arg(e, 1)); + } + + virtual optional expand(expr const &, abstract_type_context &) const { + /* This is a temporary/delayed proof step. */ + lean_unreachable(); + } + + virtual void write(serializer &) const { + /* This is a temporary/delayed proof step. */ + lean_unreachable(); + } + + virtual bool operator==(macro_definition_cell const & other) const { + cc_proof_macro_cell const * other_ptr = dynamic_cast(&other); + return other_ptr; + } + + virtual unsigned hash() const { return 23; } +}; + +/* Delayed (congruence closure procedure) proof. + This proof is a placeholder for the real one that is computed only if needed. */ +expr mk_delayed_cc_eq_proof(expr const & e1, expr const & e2) { + expr args[2] = {e1, e2}; + return mk_macro(*g_cc_proof_macro, 2, args); +} + +bool is_delayed_cc_eq_proof(expr const & e) { + return is_macro(e) && dynamic_cast(macro_def(e).raw()); +} + +static name * g_theory_proof = nullptr; + +expr mark_cc_theory_proof(expr const & pr) { + return mk_annotation(*g_theory_proof, pr); +} + +bool is_cc_theory_proof(expr const & e) { + return is_annotation(e, *g_theory_proof); +} + +expr get_cc_theory_proof_arg(expr const & pr) { + lean_assert(is_cc_theory_proof(pr)); + return get_annotation_arg(pr); +} + +class expand_delayed_cc_proofs_fn : public replace_visitor { + congruence_closure const & m_cc; + + expr visit_macro(expr const & e) { + if (is_delayed_cc_eq_proof(e)) { + expr const & lhs = macro_arg(e, 0); + expr const & rhs = macro_arg(e, 1); + return *m_cc.get_eq_proof(lhs, rhs); + } else { + return replace_visitor::visit_macro(e); + } + } + +public: + expand_delayed_cc_proofs_fn(congruence_closure const & cc):m_cc(cc) {} +}; + +expr expand_delayed_cc_proofs(congruence_closure const & cc, expr const & e) { + return expand_delayed_cc_proofs_fn(cc)(e); +} + +void initialize_congruence_util() { + g_cc_proof_name = new name("cc_proof"); + g_cc_proof_macro = new macro_definition(new cc_proof_macro_cell()); + g_theory_proof = new name("th_proof"); + register_annotation(*g_theory_proof); +} + +void finalize_congruence_util() { + delete g_cc_proof_macro; + delete g_cc_proof_name; + delete g_theory_proof; +} +} diff --git a/src/library/tactic/congruence/util.h b/src/library/tactic/congruence/util.h new file mode 100644 index 0000000000..51f6f97c74 --- /dev/null +++ b/src/library/tactic/congruence/util.h @@ -0,0 +1,21 @@ +/* +Copyright (c) 2016 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include "kernel/expr.h" + +namespace lean { +expr mk_delayed_cc_eq_proof(expr const & e1, expr const & e2); +expr mark_cc_theory_proof(expr const & pr); +expr get_cc_theory_proof_arg(expr const & pr); +bool is_cc_theory_proof(expr const & e); + +class congruence_closure; +expr expand_delayed_cc_proofs(congruence_closure const & cc, expr const & e); + +void initialize_congruence_util(); +void finalize_congruence_util(); +} diff --git a/src/library/util.cpp b/src/library/util.cpp index 9532ce6e5e..29445835b5 100644 --- a/src/library/util.cpp +++ b/src/library/util.cpp @@ -584,6 +584,8 @@ expr mk_eq_refl(abstract_type_context & ctx, expr const & a) { } expr mk_eq_symm(abstract_type_context & ctx, expr const & H) { + if (is_app_of(H, get_eq_refl_name())) + return H; expr p = ctx.whnf(ctx.infer(H)); lean_assert(is_eq(p)); expr lhs = app_arg(app_fn(p)); @@ -594,6 +596,10 @@ expr mk_eq_symm(abstract_type_context & ctx, expr const & H) { } expr mk_eq_trans(abstract_type_context & ctx, expr const & H1, expr const & H2) { + if (is_app_of(H1, get_eq_refl_name())) + return H2; + if (is_app_of(H2, get_eq_refl_name())) + return H1; expr p1 = ctx.whnf(ctx.infer(H1)); expr p2 = ctx.whnf(ctx.infer(H2)); lean_assert(is_eq(p1) && is_eq(p2));