diff --git a/src/library/unifier.cpp b/src/library/unifier.cpp index 43e2a24a24..80a7f1799e 100644 --- a/src/library/unifier.cpp +++ b/src/library/unifier.cpp @@ -377,7 +377,8 @@ struct unifier_fn { struct lazy_constraints_case_split : public case_split { lazy_list m_tail; - lazy_constraints_case_split(unifier_fn & u, justification const & j, lazy_list const & tail):case_split(u, j), m_tail(tail) {} + lazy_constraints_case_split(unifier_fn & u, justification const & j, lazy_list const & tail): + case_split(u, j), m_tail(tail) {} virtual bool next(unifier_fn & u) { return u.next_lazy_constraints_case_split(*this); } }; @@ -387,6 +388,14 @@ struct unifier_fn { virtual bool next(unifier_fn & u) { return u.next_simple_case_split(*this); } }; + struct delta_unfold_case_split : public case_split { + bool m_done; + constraint m_cnstr; + delta_unfold_case_split(unifier_fn & u, justification const & j, constraint const & c): + case_split(u, j), m_done(false), m_cnstr(c) {} + virtual bool next(unifier_fn & u) { return u.next_delta_unfold_case_split(*this); } + }; + case_split_stack m_case_splits; optional m_conflict; //!< if different from none, then there is a conflict. @@ -1172,6 +1181,37 @@ struct unifier_fn { } } + bool next_delta_unfold_case_split(delta_unfold_case_split & cs) { + if (!cs.m_done) { + cs.restore_state(*this); + cs.m_done = true; + constraint const & c = cs.m_cnstr; + expr const & lhs = cnstr_lhs_expr(c); + expr const & rhs = cnstr_rhs_expr(c); + buffer lhs_args, rhs_args; + justification j = c.get_justification(); + expr lhs_fn = get_app_rev_args(lhs, lhs_args); + expr rhs_fn = get_app_rev_args(rhs, rhs_args); + declaration d = *m_env.find(const_name(lhs_fn)); + levels lhs_lvls = const_levels(lhs_fn); + levels rhs_lvls = const_levels(lhs_fn); + bool relax = relax_main_opaque(c); + expr lhs_fn_val = instantiate_value_univ_params(d, const_levels(lhs_fn)); + expr rhs_fn_val = instantiate_value_univ_params(d, const_levels(rhs_fn)); + expr t = apply_beta(lhs_fn_val, lhs_args.size(), lhs_args.data()); + expr s = apply_beta(rhs_fn_val, rhs_args.size(), rhs_args.data()); + auto dcs = m_tc[relax]->is_def_eq(t, s, j); + if (dcs.first) { + constraints cnstrs = dcs.second.to_list(); + return process_constraints(cnstrs, mk_composite1(cs.get_jst(), mk_assumption_justification(cs.m_assumption_idx))); + } + } + // update conflict + update_conflict(mk_composite1(*m_conflict, cs.m_failed_justifications)); + pop_case_split(); + return false; + } + /** \brief Solve constraints of the form (f a_1 ... a_n) =?= (f b_1 ... b_n) where f can be expanded. We consider two possible solutions: @@ -1200,16 +1240,8 @@ struct unifier_fn { bool relax = relax_main_opaque(c); if (m_config.m_computation || module::is_definition(m_env, d.get_name()) || is_reducible_on(m_env, d.get_name())) { // add case_split for t =?= s - expr lhs_fn_val = instantiate_value_univ_params(d, const_levels(lhs_fn)); - expr rhs_fn_val = instantiate_value_univ_params(d, const_levels(rhs_fn)); - expr t = apply_beta(lhs_fn_val, lhs_args.size(), lhs_args.data()); - expr s = apply_beta(rhs_fn_val, rhs_args.size(), rhs_args.data()); - auto dcs = m_tc[relax]->is_def_eq(t, s, j); - if (dcs.first) { - // create a case split - a = mk_assumption_justification(m_next_assumption_idx); - add_case_split(std::unique_ptr(new simple_case_split(*this, j, dcs.second.to_list()))); - } + a = mk_assumption_justification(m_next_assumption_idx); + add_case_split(std::unique_ptr(new delta_unfold_case_split(*this, j, c))); } // process first case