diff --git a/src/library/tactic/smt/congruence_closure.cpp b/src/library/tactic/smt/congruence_closure.cpp index 22d3ed9924..3cd9627a30 100644 --- a/src/library/tactic/smt/congruence_closure.cpp +++ b/src/library/tactic/smt/congruence_closure.cpp @@ -98,13 +98,16 @@ public: MK_THREAD_LOCAL_GET_DEF(ext_congr_lemma_cache_manager, get_clcm); -congruence_closure::congruence_closure(type_context & ctx, state & s, defeq_canonizer::state & dcs, cc_propagation_handler * phandler): +congruence_closure::congruence_closure(type_context & ctx, state & s, defeq_canonizer::state & dcs, + cc_propagation_handler * phandler, + cc_normalizer * normalizer): m_ctx(ctx), m_defeq_canonizer(ctx, dcs), m_state(s), m_cache_ptr(get_clcm().mk(ctx.env())), m_mode(ctx.mode()), m_rel_info_getter(mk_relation_info_getter(ctx.env())), m_symm_info_getter(mk_symm_info_getter(ctx.env())), m_refl_info_getter(mk_refl_info_getter(ctx.env())), m_ac(*this, m_state.m_ac_state), - m_phandler(phandler) { + m_phandler(phandler), + m_normalizer(normalizer) { } congruence_closure::~congruence_closure() { @@ -424,13 +427,18 @@ void congruence_closure::push_eq(expr const & lhs, expr const & rhs, expr const m_todo.emplace_back(lhs, rhs, H, false); } +expr congruence_closure::normalize(expr const & e) { + if (m_normalizer) + return m_normalizer->normalize(e); + else + return e; +} + void congruence_closure::process_subsingleton_elem(expr const & e) { expr type = m_ctx.infer(e); optional ss = m_ctx.mk_subsingleton_instance(type); if (!ss) return; /* type is not a subsingleton */ - /* use defeq_canonize to "normalize" instance */ - bool dummy; - type = m_defeq_canonizer.canonize(type, dummy); + type = normalize(type); /* Make sure type has been internalized */ internalize_core(type, none_expr()); /* Try to find representative */ @@ -1202,8 +1210,10 @@ optional congruence_closure::get_proof(expr const & e1, expr const & e2) c } void congruence_closure::push_subsingleton_eq(expr const & a, expr const & b) { - expr A = m_ctx.infer(a); - expr B = m_ctx.infer(b); + /* Remark: we must use normalize here because we have use it before + internalizing the types of 'a' and 'b'. */ + expr A = normalize(m_ctx.infer(a)); + expr B = normalize(m_ctx.infer(b)); /* TODO(Leo): check if the following test is a performance bottleneck */ if (m_ctx.relaxed_is_def_eq(A, B)) { /* TODO(Leo): to improve performance we can create the following proof lazily */ diff --git a/src/library/tactic/smt/congruence_closure.h b/src/library/tactic/smt/congruence_closure.h index e5f04a27d6..475b388923 100644 --- a/src/library/tactic/smt/congruence_closure.h +++ b/src/library/tactic/smt/congruence_closure.h @@ -11,6 +11,7 @@ Author: Leonardo de Moura #include "library/congr_lemma.h" #include "library/relation_manager.h" #include "library/defeq_canonizer.h" +#include "library/tactic/simp_result.h" #include "library/tactic/smt/theory_ac.h" namespace lean { @@ -26,6 +27,15 @@ public: void propagated(buffer const & p) { propagated(p.size(), p.data()); } }; +/* The congruence_closure module (optionally) uses a normalizer. + The idea is to use it (if available) to normalize auxiliary expressions + produced by internal propagation rules (e.g., subsingleton propagator). */ +class cc_normalizer { +public: + virtual ~cc_normalizer() {} + virtual expr normalize(expr const & e) = 0; +}; + class congruence_closure { /* Key for the equality congruence table. */ struct congr_key { @@ -175,6 +185,7 @@ private: refl_info_getter m_refl_info_getter; theory_ac m_ac; cc_propagation_handler * m_phandler; + cc_normalizer * m_normalizer; friend class theory_ac; int compare_symm(expr lhs1, expr rhs1, expr lhs2, expr rhs2) const; @@ -254,10 +265,13 @@ private: void add_eqv_core(expr const & lhs, expr const & rhs, expr const & H, bool heq_proof); bool check_eqc(expr const & e) const; + expr normalize(expr const & e); + friend ext_congr_lemma_cache_ptr const & get_cache_ptr(congruence_closure const & cc); public: congruence_closure(type_context & ctx, state & s, defeq_canonizer::state & dcs, - cc_propagation_handler * phandler = nullptr); + cc_propagation_handler * phandler = nullptr, + cc_normalizer * normalizer = nullptr); ~congruence_closure(); environment const & env() const { return m_ctx.env(); } diff --git a/src/library/tactic/smt/smt_state.cpp b/src/library/tactic/smt/smt_state.cpp index a944b25c21..2e78725615 100644 --- a/src/library/tactic/smt/smt_state.cpp +++ b/src/library/tactic/smt/smt_state.cpp @@ -38,8 +38,9 @@ smt_goal::smt_goal(smt_config const & cfg): smt::smt(type_context & ctx, defeq_can_state & dcs, smt_goal & g): m_ctx(ctx), + m_dcs(dcs), m_goal(g), - m_cc(ctx, m_goal.m_cc_state, dcs, this) { + m_cc(ctx, m_goal.m_cc_state, dcs, this, this) { } smt::~smt() { @@ -99,6 +100,14 @@ void smt::ematch_using(hinst_lemma const & lemma, buffer & result) { ::lean::ematch(m_ctx, m_goal.m_em_state, m_cc, lemma, false, result); } +static dsimplify_fn mk_dsimp(type_context & ctx, defeq_can_state & dcs, smt_pre_config const & cfg); + +expr smt::normalize(expr const & e) { + type_context::zeta_scope _1(m_ctx, m_goal.m_pre_config.m_zeta); + dsimplify_fn dsimp = mk_dsimp(m_ctx, m_dcs, m_goal.m_pre_config); + return dsimp(e); +} + struct vm_smt_goal : public vm_external { smt_goal m_val; vm_smt_goal(smt_goal const & v):m_val(v) {} diff --git a/src/library/tactic/smt/smt_state.h b/src/library/tactic/smt/smt_state.h index c6bba772a4..e2f4369101 100644 --- a/src/library/tactic/smt/smt_state.h +++ b/src/library/tactic/smt/smt_state.h @@ -42,16 +42,17 @@ public: void set_lemmas(hinst_lemmas const & lemmas) { m_em_state.set_lemmas(lemmas); } }; -class smt : public cc_propagation_handler { +class smt : public cc_propagation_handler, public cc_normalizer { private: type_context & m_ctx; + defeq_can_state & m_dcs; smt_goal & m_goal; congruence_closure m_cc; lbool get_value_core(expr const & e); lbool get_value(expr const & e); virtual void propagated(unsigned n, expr const * p) override; - + virtual expr normalize(expr const & e) override; public: smt(type_context & ctx, defeq_can_state & dcs, smt_goal & g); virtual ~smt();