diff --git a/src/library/idx_metavar.cpp b/src/library/idx_metavar.cpp index 380ca2854c..92290c344e 100644 --- a/src/library/idx_metavar.cpp +++ b/src/library/idx_metavar.cpp @@ -86,4 +86,16 @@ bool has_idx_metavar(expr const & e) { }); return found; } + +bool has_idx_expr_metavar(expr const & e) { + if (!has_expr_metavar(e)) + return false; + bool found = false; + for_each(e, [&](expr const & e, unsigned) { + if (found || !has_expr_metavar(e)) return false; + if (is_idx_metavar(e)) found = true; + return true; + }); + return found; +} } diff --git a/src/library/idx_metavar.h b/src/library/idx_metavar.h index 4fc538e5d2..a45855d500 100644 --- a/src/library/idx_metavar.h +++ b/src/library/idx_metavar.h @@ -29,6 +29,7 @@ unsigned to_meta_idx(expr const & e); /** \brief Return true iff \c e contains idx metavariables or universe metavariables */ bool has_idx_metavar(expr const & e); +bool has_idx_expr_metavar(expr const & e); void initialize_idx_metavar(); void finalize_idx_metavar(); diff --git a/src/library/tactic/smt/congruence_closure.h b/src/library/tactic/smt/congruence_closure.h index 475b388923..9bfc12c057 100644 --- a/src/library/tactic/smt/congruence_closure.h +++ b/src/library/tactic/smt/congruence_closure.h @@ -265,8 +265,6 @@ private: void add_eqv_core(expr const & lhs, expr const & rhs, expr const & H, bool heq_proof); bool check_eqc(expr const & e) const; - expr normalize(expr const & e); - friend ext_congr_lemma_cache_ptr const & get_cache_ptr(congruence_closure const & cc); public: congruence_closure(type_context & ctx, state & s, defeq_canonizer::state & dcs, @@ -317,8 +315,23 @@ public: optional mk_ext_congr_lemma(expr const & e) const; + optional is_ac(expr const & e) { + if (m_state.m_config.m_ac) return m_ac.is_ac(e); + else return none_expr(); + } + entry const * get_entry(expr const & e) const { return m_state.m_entries.find(e); } bool check_invariant() const { return m_state.check_invariant(); } + + expr normalize(expr const & e); + + class state_scope { + congruence_closure & m_cc; + state m_saved_state; + public: + state_scope(congruence_closure & cc):m_cc(cc), m_saved_state(cc.m_state) {} + ~state_scope() { m_cc.m_state = m_saved_state; } + }; }; typedef congruence_closure::state cc_state; diff --git a/src/library/tactic/smt/ematch.cpp b/src/library/tactic/smt/ematch.cpp index 9c7972de2c..6aaee06116 100644 --- a/src/library/tactic/smt/ematch.cpp +++ b/src/library/tactic/smt/ematch.cpp @@ -6,7 +6,9 @@ Author: Leonardo de Moura */ #include #include "util/interrupt.h" +#include "util/small_object_allocator.h" #include "library/trace.h" +#include "library/util.h" #include "library/constants.h" #include "library/app_builder.h" #include "library/fun_info.h" @@ -80,23 +82,402 @@ bool ematch_state::save_instance(expr const & lemma, buffer const & args) return save_instance(key); } -struct ematch_fn { - enum frame_kind { DefEqOnly, EqvOnly, Match, MatchSS /* match subsingleton */, Continue }; - typedef std::tuple entry; - typedef list state; - typedef list choice; +/* Allocator for ematching constraints. */ +MK_THREAD_LOCAL_GET(small_object_allocator, get_emc_allocator, "ematch constraint"); +enum class ematch_cnstr_kind { DefEqOnly, EqvOnly, Match, MatchAC, MatchSS /* match subsingleton */, Continue }; +class ematch_cnstr; +/** \brief Base class for Ematching constraints. + + Remark: these objects are thread local. So, we don't need synchronization. */ +struct ematch_cnstr_cell { + unsigned m_rc; + ematch_cnstr_kind m_kind; + + void inc_ref() { m_rc++; } + bool dec_ref_core() { lean_assert(m_rc > 0); m_rc--; return m_rc == 0; } + void dec_ref() { if (dec_ref_core()) { dealloc(); } } + void dealloc(); + ematch_cnstr_cell(ematch_cnstr_kind k):m_rc(0), m_kind(k) {} + ematch_cnstr_kind kind() const { return m_kind; } + unsigned get_rc() const { return m_rc; } +}; + +/* Ematching constraint smart pointer */ +class ematch_cnstr { + friend class ematch_cnstr_cell; + ematch_cnstr_cell * m_data; +public: + ematch_cnstr():m_data(nullptr) {} + explicit ematch_cnstr(ematch_cnstr_cell * c):m_data(c) { m_data->inc_ref(); } + ematch_cnstr(ematch_cnstr const & o):m_data(o.m_data) { m_data->inc_ref(); } + ematch_cnstr(ematch_cnstr && o):m_data(o.m_data) { o.m_data = nullptr; } + ~ematch_cnstr() { if (m_data) m_data->dec_ref(); } + operator ematch_cnstr_cell*() const { return m_data; } + + ematch_cnstr & operator=(ematch_cnstr const & s) { + if (s.m_data) s.m_data->inc_ref(); + ematch_cnstr_cell * new_data = s.m_data; + if (m_data) m_data->dec_ref(); + m_data = new_data; + return *this; + } + + ematch_cnstr & operator=(ematch_cnstr && s) { + if (m_data) m_data->dec_ref(); + m_data = s.m_data; + s.m_data = nullptr; + return *this; + } + + ematch_cnstr_kind kind() const { return m_data->kind(); } + ematch_cnstr_cell * raw() const { return m_data; } +}; + +struct ematch_eq_cnstr : public ematch_cnstr_cell { + expr m_p; + expr m_t; + ematch_eq_cnstr(ematch_cnstr_kind k, expr const & p, expr const & t): + ematch_cnstr_cell(k), m_p(p), m_t(t) {} +}; + +struct ematch_ac_cnstr : public ematch_cnstr_cell { + expr m_op; + list m_p; + list m_t; + ematch_ac_cnstr(expr const & op, list const & p, list const & t): + ematch_cnstr_cell(ematch_cnstr_kind::MatchAC), m_op(op), m_p(p), m_t(t) {} +}; + +struct ematch_continue : public ematch_cnstr_cell { + expr m_p; + ematch_continue(expr const & p): + ematch_cnstr_cell(ematch_cnstr_kind::Continue), m_p(p) {} +}; + +inline bool is_eq_cnstr(ematch_cnstr_cell const * c) { + return + c->kind() == ematch_cnstr_kind::Match || c->kind() == ematch_cnstr_kind::MatchSS || + c->kind() == ematch_cnstr_kind::DefEqOnly || c->kind() == ematch_cnstr_kind::EqvOnly; +} +static bool is_ac_cnstr(ematch_cnstr_cell const * c) { return c->kind() == ematch_cnstr_kind::MatchAC; } +static bool is_continue(ematch_cnstr_cell const * c) { return c->kind() == ematch_cnstr_kind::Continue; } + +static ematch_eq_cnstr * to_eq_cnstr(ematch_cnstr_cell * c) { lean_assert(is_eq_cnstr(c)); return static_cast(c); } +static ematch_ac_cnstr * to_ac_cnstr(ematch_cnstr_cell * c) { lean_assert(is_ac_cnstr(c)); return static_cast(c); } +static ematch_continue * to_continue(ematch_cnstr_cell * c) { lean_assert(is_continue(c)); return static_cast(c); } + +void ematch_cnstr_cell::dealloc() { + lean_assert(get_rc() == 0); + if (is_ac_cnstr(this)) { + to_ac_cnstr(this)->~ematch_ac_cnstr(); + get_emc_allocator().deallocate(sizeof(ematch_ac_cnstr), this); + } else if (is_continue(this)) { + to_continue(this)->~ematch_continue(); + get_emc_allocator().deallocate(sizeof(ematch_continue), this); + } else { + to_eq_cnstr(this)->~ematch_eq_cnstr(); + get_emc_allocator().deallocate(sizeof(ematch_eq_cnstr), this); + } +} + +static ematch_cnstr mk_eq_cnstr(ematch_cnstr_kind k, expr const & p, expr const & t) { + return ematch_cnstr(new (get_emc_allocator().allocate(sizeof(ematch_eq_cnstr))) ematch_eq_cnstr(k, p, t)); +} + +static ematch_cnstr mk_match_ac_cnstr(expr const & op, list const & p, list const & t) { + return ematch_cnstr(new (get_emc_allocator().allocate(sizeof(ematch_ac_cnstr))) ematch_ac_cnstr(op, p, t)); +} + +static ematch_cnstr mk_continue(expr const & p) { + return ematch_cnstr(new (get_emc_allocator().allocate(sizeof(ematch_continue))) ematch_continue(p)); +} + +static ematch_cnstr mk_match_eq_cnstr(expr const & p, expr const & t) { return mk_eq_cnstr(ematch_cnstr_kind::Match, p, t); } +static ematch_cnstr mk_match_ss_cnstr(expr const & p, expr const & t) { return mk_eq_cnstr(ematch_cnstr_kind::MatchSS, p, t); } +static ematch_cnstr mk_eqv_cnstr(expr const & p, expr const & t) { return mk_eq_cnstr(ematch_cnstr_kind::EqvOnly, p, t); } +static ematch_cnstr mk_defeq_cnstr(expr const & p, expr const & t) { return mk_eq_cnstr(ematch_cnstr_kind::DefEqOnly, p, t); } + +static expr const & cnstr_p(ematch_cnstr const & c) { return to_eq_cnstr(c)->m_p; } +static expr const & cnstr_t(ematch_cnstr const & c) { return to_eq_cnstr(c)->m_t; } +static expr const & cont_p(ematch_cnstr const & c) { return to_continue(c)->m_p; } +static expr const & ac_op(ematch_cnstr const & c) { return to_ac_cnstr(c)->m_op; } +static list const & ac_p(ematch_cnstr const & c) { return to_ac_cnstr(c)->m_p; } +static list const & ac_t(ematch_cnstr const & c) { return to_ac_cnstr(c)->m_t; } + +/* + Matching modulo equalities. + + This module also supports matching modulo AC. + + The procedure is (supposed to be) complete for E-matching and AC-matching. + However, it is currently incomplete for AC-E-matching. + + Here are matching problems that are not supported. + Assuming + is an AC operation. + + 1) Given { a + b = f c }, solve (?x + f ?x) =?= (a + c + b) + It misses the solution ?x := c + + 2) Given { a = a + a }, solve (?x + ?x + ?x + ?y) =?= (a + b) + It misses the solution ?x := a, ?y := b + + The following implementation is based on standard algorithms for E-matching and + AC-matching. The following extensions are supported. + + - E-matching modulo heterogeneous equalities. + Casts are automatically introduced. + Moreover, in standard E-matching, a sub-problem such as + ?x =?= t + where ?x is unassigned, is solved by assigning ?x := t. + We add the following extension when t is in a heterogeneous equivalence class. + We peek a term t_i in eqc(t) for each different type, and then create + the subproblems: + ?x := t_1 \/ ... \/ ?x := t_k + + - Uses higher-order pattern matching whenever higher-order sub-patterns + are found. Example: (?f a) =?= (g a a) + + - Subsingleton support. For example, suppose (a b : A), and A is a subsingleton. + Then, the following pattern is solved. + (f a ?x) =?= (f b c) + This is useful when we have proofs embedded in terms. + + - Equality expansion preprocessing step for AC-matching subproblems. + Given an AC-matching subproblem p =?= ...+t+... + For each term t' headed by + in eqc(t), we generate a new case: + p =?= ...+t'+... + + Limitations: + 1- A term t will be expanded at most once per AC subproblem. + Example: given {a = a + a}, and constraint (?x + ?x + ?x + ?y) =?= (a + b). + We produce two cases: + ?x + ?x + ?x + ?y =?= a + b + \/ + ?x + ?x + ?x + ?y =?= a + a + b + + 2- We do not consider subterms of the form (t+s). + Example: give {a + b = f c}, and constraint {?x + f ?x =?= a + c + b}, + this procedure will not generate the new case {?x + f ?x =?= f c + c} + by replacing (a + b) with (f c). +*/ +struct ematch_fn { + typedef list state; type_context & m_ctx; ematch_state & m_em_state; congruence_closure & m_cc; buffer & m_new_instances; state m_state; - buffer m_choice_stack; + buffer m_choice_stack; ematch_fn(type_context & ctx, ematch_state & ems, congruence_closure & cc, buffer & new_insts): m_ctx(ctx), m_em_state(ems), m_cc(cc), m_new_instances(new_insts) {} + expr instantiate_mvars(expr const & e) { + return m_ctx.instantiate_mvars(e); + } + + /* Similar to instantiate_mvars, but it makes sure the assignment at m_ctx is not modified by composition. + That is, suppose we have the assignment { ?x := f ?y, ?y := a }, and we instantiate (g ?x). + The result is (g (f a)), but this method prevents the assignment to be modified to + { ?x := f a, ?y := a } + + We need this feature for AC matching, where we want to be able to quickly detect "partially solved" + variables of the form (?x := ?y + s) where s does not contain metavariables. */ + expr safe_instantiate_mvars(expr const & e) { + m_ctx.push_scope(); + expr r = instantiate_mvars(e); + m_ctx.pop_scope(); + return r; + } + + bool is_metavar(expr const & e) { return m_ctx.is_mvar(e); } + bool is_meta(expr const & e) { return is_metavar(get_app_fn(e)); } + bool has_expr_metavar(expr const & e) { return has_idx_expr_metavar(e); } + optional is_ac(expr const & /* e */) { + // TODO(Leo): enable AC matching when it is done + return none_expr(); + // return m_cc.is_ac(e); + } + optional get_binary_op(expr const & e) { + if (is_app(e) && is_app(app_fn(e))) + return some_expr(app_fn(app_fn(e))); + else + return none_expr(); + } + + expr internalize(expr const & e) { + expr new_e = m_cc.normalize(e); + m_cc.internalize(new_e); + return new_e; + } + + bool is_ground_eq(expr const & p, expr const & t) { + lean_assert(!has_expr_metavar(p)); + lean_assert(!has_expr_metavar(t)); + return m_cc.is_eqv(p, t) || m_ctx.is_def_eq(p, t); + } + + /* Return true iff e is a metavariable, and we have an assignment of the + form e := ?m + s, where + is an AC operator, and ?m is another metavariable. */ + bool is_partially_solved(expr const & e) { + lean_assert(is_metavar(e)); + if (auto v = m_ctx.get_assignment(e)) { + return is_ac(*v) && m_ctx.is_mvar(app_arg(app_fn(*v))); + } else { + return false; + } + } + + void flat_ac(expr const & op, expr const & e, buffer & args) { + if (optional curr_op = get_binary_op(e)) { + if (m_ctx.is_def_eq(op, *curr_op)) { + flat_ac(op, app_arg(app_fn(e)), args); + flat_ac(op, app_arg(e), args); + return; + } + } + args.push_back(e); + } + + /* Cancel ground terms that occur in p_args and t_args. + Example: + Given + [?x, 0, ?y] [a, b, 0, c], + the result is: + [?x, ?y] [a, b, c] + */ + void ac_cancel_terms(buffer & p_args, buffer & t_args) { + unsigned j = 0; + for (unsigned i = 0; i < p_args.size(); i++) { + if (has_expr_metavar(p_args[i])) { + p_args[j] = p_args[i]; + j++; + } else { + expr p = internalize(p_args[i]); + unsigned k = 0; + for (; k < t_args.size(); k++) { + if (is_ground_eq(p, t_args[k])) { + break; + } + } + if (k == t_args.size()) { + p_args[j] = p; + j++; + } else { + // cancelled + t_args.erase(k); + } + } + } + p_args.shrink(j); + } + + expr mk_ac_term(expr const & op, buffer const & args) { + lean_assert(!args.empty()); + expr r = args.back(); + unsigned i = args.size() - 1; + while (i > 0) { + --i; + r = mk_app(op, args[i], r); + } + return r; + } + + expr mk_ac_term(expr const & op, list const & args) { + buffer b; + to_buffer(args, b); + return mk_ac_term(op, b); + } + + void display_ac_cnstr(io_state_stream const & out, ematch_cnstr const & c) { + expr p = mk_ac_term(ac_op(c), ac_p(c)); + expr t = mk_ac_term(ac_op(c), ac_t(c)); + auto fmt = out.get_formatter(); + format r = group(fmt(p) + line() + format("=?=") + line() + fmt(t)); + out << r; + } + + void process_new_ac_cnstr(state const & s, expr const & p, expr const & t, buffer & new_states) { + optional op = is_ac(t); + lean_assert(op); + buffer p_args, t_args; + flat_ac(*op, p, p_args); + flat_ac(*op, t, t_args); + lean_assert(t_args.size() >= 2); + if (p_args.empty()) { + /* This can happen if we fail to unify the operator in p with the one in t. */ + return; + } + lean_assert(p_args.size() >= 2); + ac_cancel_terms(p_args, t_args); + if (p_args.size() == 1 && t_args.size() == 1) { + new_states.push_back(cons(mk_match_eq_cnstr(p_args[0], t_args[0]), s)); + return; + } + list ps = to_list(p_args); + buffer new_t_args; + /* Create a family of AC-matching constraints by replacing t-arguments + with op-applications that are in the same equivalence class. + + Example: given (a = b + c) (d = e + f) and t is of the form (a + d). + expand, will add the following AC constraints + + p =?= a + d + p =?= a + e + f + p =?= b + c + d + p =?= b + c + e + f + + To avoid non termination, we unfold a t_arg at most once. + Here is an example that would produce non-termination if + we did not use unfolded. + + Given (a = a + a) and t is of the form (a + d). + We would be able to produce + + p =?= a + d + p =?= a + a + d + ... + p =?= a + ... + a + d + ... + */ + std::function + expand = [&](unsigned i, rb_expr_tree const & unfolded) { + check_system("ematching"); + if (i == t_args.size()) { + ematch_cnstr c = mk_match_ac_cnstr(*op, ps, to_list(new_t_args)); + lean_trace(name({"debug", "ematch"}), tout() << "new ac constraint: "; display_ac_cnstr(tout(), c); tout() << "\n";); + new_states.push_back(cons(c, s)); + } else { + expr const & t_arg = t_args[i]; + new_t_args.push_back(t_arg); + expand(i+1, unfolded); + new_t_args.pop_back(); + /* search for op-applications in eqc(t_arg) */ + rb_expr_tree new_unfolded = unfolded; + bool first = true; + expr it = t_arg; + do { + if (auto op2 = is_ac(it)) { + if (*op == *op2) { + unsigned sz = t_args.size(); + flat_ac(*op, it, t_args); + if (first) { + new_unfolded.insert(t_arg); + first = false; + } + expand(i+1, new_unfolded); + t_args.shrink(sz); + } + } + it = m_cc.get_next(it); + } while (it != t_arg); + } + }; + expand(0, rb_expr_tree()); + } + void push_states(buffer & new_states) { if (new_states.size() == 1) { lean_trace(name({"debug", "ematch"}), tout() << "(only one match)\n";); @@ -105,19 +486,45 @@ struct ematch_fn { lean_trace(name({"debug", "ematch"}), tout() << "# matches: " << new_states.size() << "\n";); m_state = new_states.back(); new_states.pop_back(); - choice c = to_list(new_states); - m_choice_stack.push_back(c); - m_ctx.push_scope(); + m_choice_stack.append(new_states); + for (unsigned i = 0; i < new_states.size(); i++) + m_ctx.push_scope(); + } + } + + bool ac_merge_clash_p(expr const & p, expr const & t) { + lean_assert(is_metavar(p) && is_partially_solved(p)); + tout() << "ac_merge_clash_p: " << p << " =?= " << t << "\n"; + // TODO(Leo): + lean_unreachable(); + } + + bool is_ac_eqv(expr const & p, expr const & t) { + lean_assert(is_ac(t)); + if (is_metavar(p) && is_partially_solved(p)) { + return ac_merge_clash_p(p, t); + } else { + /* When AC support is enabled, metavariables may be assigned to terms + that have not been internalized. */ + expr new_p = safe_instantiate_mvars(p); + if (!has_expr_metavar(new_p)) { + new_p = internalize(new_p); + return is_ground_eq(new_p, t); + } else { + return m_ctx.is_def_eq(new_p, t); + } } } bool is_eqv(expr const & p, expr const & t) { - if (!has_expr_metavar(p)) { - return m_cc.is_eqv(p, t) || m_ctx.is_def_eq(p, t); + if (is_ac(t)) { + return is_ac_eqv(p, t); + } else if (!has_expr_metavar(p)) { + return is_ground_eq(p, t); } else if (is_meta(p)) { expr const & m = get_app_fn(p); if (!m_ctx.is_assigned(m)) { - expr p_type = m_ctx.instantiate_mvars(m_ctx.infer(p)); + expr p_type = safe_instantiate_mvars(m_ctx.infer(p)); expr t_type = m_ctx.infer(t); if (m_ctx.is_def_eq(p_type, t_type)) { /* Types are definitionally equal. So, we just assign */ @@ -142,8 +549,8 @@ struct ematch_fn { Important: we must process arguments from left to right. Otherwise, the "trick" above will not work. */ - m_cc.internalize(p_type); - m_cc.internalize(t_type); + p_type = internalize(p_type); + t_type = internalize(t_type); if (auto H = m_cc.get_eq_proof(t_type, p_type)) { expr cast_H_t = mk_cast(m_ctx, *H, t); return m_ctx.is_def_eq(p, cast_H_t); @@ -156,12 +563,15 @@ struct ematch_fn { using cc since they contain metavariables */ return false; } + } else if (is_metavar(p) && is_partially_solved(p)) { + return ac_merge_clash_p(p, t); } else { - expr new_p = m_ctx.instantiate_mvars(p); - if (!has_expr_metavar(new_p)) - return m_cc.is_eqv(new_p, t) || m_ctx.is_def_eq(new_p, t); - else + expr new_p = safe_instantiate_mvars(p); + if (!has_expr_metavar(new_p)) { + return is_ground_eq(new_p, t); + } else { return m_ctx.is_def_eq(new_p, t); + } } } else { return m_ctx.is_def_eq(p, t); @@ -181,7 +591,7 @@ struct ematch_fn { expr it_type = m_ctx.infer(it); if (!types_seen.find(it_type)) { types_seen.insert(it_type); - new_states.emplace_back(cons(entry(EqvOnly, p, it), m_state)); + new_states.push_back(cons(mk_eqv_cnstr(p, it), m_state)); } it = m_cc.get_next(it); } while (it != t); @@ -205,19 +615,22 @@ struct ematch_fn { list sinfo = get_subsingleton_info(m_ctx, fn, t_args.size()); list const * it1 = &finfo.get_params_info(); list const *it2 = &sinfo; - buffer new_entries; + buffer new_cnstrs; for (unsigned i = 0; i < t_args.size(); i++) { if (*it1 && head(*it1).is_inst_implicit()) { - new_entries.emplace_back(DefEqOnly, p_args[i], t_args[i]); + new_cnstrs.push_back(mk_defeq_cnstr(p_args[i], t_args[i])); + lean_assert(new_cnstrs.back().kind() == ematch_cnstr_kind::DefEqOnly); } else if (*it2 && head(*it2).is_subsingleton()) { - new_entries.emplace_back(MatchSS, p_args[i], t_args[i]); + new_cnstrs.push_back(mk_match_ss_cnstr(p_args[i], t_args[i])); + lean_assert(new_cnstrs.back().kind() == ematch_cnstr_kind::MatchSS); } else { - new_entries.emplace_back(Match, p_args[i], t_args[i]); + new_cnstrs.push_back(mk_match_eq_cnstr(p_args[i], t_args[i])); + lean_assert(new_cnstrs.back().kind() == ematch_cnstr_kind::Match); } if (*it1) it1 = &tail(*it1); if (*it2) it2 = &tail(*it2); } - s = to_list(new_entries.begin(), new_entries.end(), s); + s = to_list(new_cnstrs.begin(), new_cnstrs.end(), s); return true; } else { return false; @@ -226,8 +639,8 @@ struct ematch_fn { bool process_match(expr const & p, expr const & t) { lean_trace(name({"debug", "ematch"}), - expr new_p = m_ctx.instantiate_mvars(p); - expr new_p_type = m_ctx.instantiate_mvars(m_ctx.infer(p)); + expr new_p = safe_instantiate_mvars(p); + expr new_p_type = safe_instantiate_mvars(m_ctx.infer(p)); expr t_type = m_ctx.infer(t); tout() << "try process_match: " << p << " ::= " << new_p << " : " << new_p_type << " <=?=> " << t << " : " << t_type << "\n";); @@ -264,7 +677,9 @@ struct ematch_fn { buffer new_states; for (expr const & c : candidates) { state new_state = m_state; - if (match_args(new_state, p_args, c)) { + if (is_ac(c)) { + process_new_ac_cnstr(new_state, p, t, new_states); + } else if (match_args(new_state, p_args, c)) { lean_trace(name({"debug", "ematch"}), tout() << "match: " << c << "\n";); new_states.push_back(new_state); } @@ -288,15 +703,8 @@ struct ematch_fn { }); if (new_states.empty()) { return false; - } else if (new_states.size() == 1) { - m_state = new_states[0]; - return true; } else { - m_state = new_states.back(); - new_states.pop_back(); - choice c = to_list(new_states); - m_choice_stack.push_back(c); - m_ctx.push_scope(); + push_states(new_states); return true; } } else { @@ -308,19 +716,18 @@ struct ematch_fn { typeof(p) and typeof(t) are subsingletons */ bool process_matchss(expr const & p, expr const & t) { lean_trace(name({"debug", "ematch"}), - expr new_p = m_ctx.instantiate_mvars(p); - expr new_p_type = m_ctx.instantiate_mvars(m_ctx.infer(p)); + expr new_p = safe_instantiate_mvars(p); + expr new_p_type = safe_instantiate_mvars(m_ctx.infer(p)); expr t_type = m_ctx.infer(t); tout() << "process_matchss: " << p << " ::= " << new_p << " : " << new_p_type << " <=?=> " << t << " : " << t_type << "\n";); - if (!is_metavar(p)) { /* If p is not a metavariable we simply ignore it. We should improve this case in the future. */ lean_trace(name({"debug", "ematch"}), tout() << "(p not a metavar)\n";); return true; } - expr p_type = m_ctx.instantiate_mvars(m_ctx.infer(p)); + expr p_type = safe_instantiate_mvars(m_ctx.infer(p)); expr t_type = m_ctx.infer(t); if (m_ctx.is_def_eq(p_type, t_type)) { bool success = m_ctx.is_def_eq(p, t); @@ -329,7 +736,7 @@ struct ematch_fn { return success; } else { /* Check if the types are provably equal, and cast t */ - m_cc.internalize(p_type); + p_type = internalize(p_type); if (auto H = m_cc.get_eq_proof(t_type, p_type)) { expr cast_H_t = mk_cast(m_ctx, *H, t); bool success = m_ctx.is_def_eq(p, cast_H_t); @@ -342,43 +749,61 @@ struct ematch_fn { return false; } + bool process_defeq_only(ematch_cnstr const & c) { + expr const & p = cnstr_p(c); + expr const & t = cnstr_t(c); + bool success = m_ctx.is_def_eq(p, t); + lean_trace(name({"debug", "ematch"}), + expr new_p = safe_instantiate_mvars(p); + expr new_p_type = safe_instantiate_mvars(m_ctx.infer(p)); + expr t_type = m_ctx.infer(t); + tout() << "must be def-eq: " << new_p << " : " << new_p_type + << " =?= " << t << " : " << t_type + << " ... " << (success ? "succeeded" : "failed") << "\n";); + return success; + } + + bool process_eqv_only(ematch_cnstr const & c) { + expr const & p = cnstr_p(c); + expr const & t = cnstr_t(c); + bool success = is_eqv(p, t); + lean_trace(name({"debug", "ematch"}), + expr new_p = safe_instantiate_mvars(p); + expr new_p_type = safe_instantiate_mvars(m_ctx.infer(p)); + expr t_type = m_ctx.infer(t); + tout() << "must be eqv: " << new_p << " : " << new_p_type << " =?= " + << t << " : " << t_type << " ... " << (success ? "succeeded" : "failed") << "\n";); + return success; + } + + bool process_match_ac(ematch_cnstr const & /* c */) { + // TODO(Leo) + lean_unreachable(); + } + bool is_done() const { return is_nil(m_state); } bool process_next() { lean_assert(!is_done()); - frame_kind kind; expr p, t; - std::tie(kind, p, t) = head(m_state); - m_state = tail(m_state); + /* TODO(Leo): select easy constraint first */ + ematch_cnstr c = head(m_state); + m_state = tail(m_state); - bool success; - switch (kind) { - case DefEqOnly: - success = m_ctx.is_def_eq(p, t); - lean_trace(name({"debug", "ematch"}), - expr new_p = m_ctx.instantiate_mvars(p); - expr new_p_type = m_ctx.instantiate_mvars(m_ctx.infer(p)); - expr t_type = m_ctx.infer(t); - tout() << "must be def-eq: " << new_p << " : " << new_p_type - << " =?= " << t << " : " << t_type - << " ... " << (success ? "succeeded" : "failed") << "\n";); - return success; - case Match: - return process_match(p, t); - case EqvOnly: - success = is_eqv(p, t); - lean_trace(name({"debug", "ematch"}), - expr new_p = m_ctx.instantiate_mvars(p); - expr new_p_type = m_ctx.instantiate_mvars(m_ctx.infer(p)); - expr t_type = m_ctx.infer(t); - tout() << "must be eqv: " << new_p << " : " << new_p_type << " =?= " - << t << " : " << t_type << " ... " << (success ? "succeeded" : "failed") << "\n";); - return success; - case MatchSS: - return process_matchss(p, t); - case Continue: - return process_continue(p); + switch (c.kind()) { + case ematch_cnstr_kind::DefEqOnly: + return process_defeq_only(c); + case ematch_cnstr_kind::Match: + return process_match(cnstr_p(c), cnstr_t(c)); + case ematch_cnstr_kind::EqvOnly: + return process_eqv_only(c); + case ematch_cnstr_kind::MatchSS: + return process_matchss(cnstr_p(c), cnstr_t(c)); + case ematch_cnstr_kind::Continue: + return process_continue(cont_p(c)); + case ematch_cnstr_kind::MatchAC: + return process_match_ac(c); } lean_unreachable(); } @@ -388,12 +813,8 @@ struct ematch_fn { if (m_choice_stack.empty()) return false; m_ctx.pop_scope(); - lean_assert(m_choice_stack.back()); - m_state = head(m_choice_stack.back()); - m_ctx.push_scope(); - m_choice_stack.back() = tail(m_choice_stack.back()); - if (!m_choice_stack.back()) - m_choice_stack.pop_back(); + m_state = m_choice_stack.back(); + m_choice_stack.pop_back(); return true; } @@ -427,7 +848,7 @@ struct ematch_fn { } for (expr & lemma_arg : lemma_args) { - lemma_arg = m_ctx.instantiate_mvars(lemma_arg); + lemma_arg = instantiate_mvars(lemma_arg); if (has_idx_metavar(lemma_arg)) { lean_trace(name({"debug", "ematch"}), tout() << "instantiation failure [" << lemma.m_id << "], " << @@ -440,13 +861,16 @@ struct ematch_fn { return; // already added this instance } - expr new_inst = m_ctx.instantiate_mvars(lemma.m_prop); - if (has_idx_metavar(new_inst)) + expr new_inst = instantiate_mvars(lemma.m_prop); + if (has_idx_metavar(new_inst)) { + lean_trace(name({"debug", "ematch"}), + tout() << "new instance contains unassigned metavariables\n" << new_inst << "\n";); return; // result contains temporary metavariables + } lean_trace("ematch", tout() << "instance [" << lemma.m_id << "]: " << new_inst << "\n";); - expr new_proof = m_ctx.instantiate_mvars(lemma.m_proof); + expr new_proof = instantiate_mvars(lemma.m_proof); m_new_instances.emplace_back(new_inst, new_proof); } @@ -477,78 +901,88 @@ struct ematch_fn { unsigned i = ps.size(); while (i > 1) { --i; - s = cons(entry(Continue, ps[i], expr()), s); + s = cons(mk_continue(ps[i]), s); } return s; } - void ematch_core(hinst_lemma const & lemma, state const & s, buffer const & p0_args, expr const & t) { - lean_trace(name({"debug", "ematch"}), - tout() << "ematch " << lemma.m_id << " [using] " << t << "\n";); + /* Ematch p =?= t with initial state init. p is the pattern, and t is a term. + The inital state init is used for multipatterns. + The given lemma is instantiated for each solution found. + The new instances are stored at m_new_instances. */ + void main(hinst_lemma const & lemma, state const & init, expr const & p, expr const & t) { type_context::tmp_mode_scope scope(m_ctx, lemma.m_num_uvars, lemma.m_num_mvars); + lean_assert(!has_idx_metavar(t)); clear_choice_stack(); - m_state = s; - if (!match_args(m_state, p0_args, t)) + m_state = init; + buffer p_args; + expr const & fn = get_app_args(p, p_args); + if (!m_ctx.is_def_eq(fn, get_app_fn(t))) + return; + if (!match_args(m_state, p_args, t)) return; search(lemma); } - void ematch(hinst_lemma const & lemma, multi_pattern const & mp, expr const & t) { + void ematch_term(hinst_lemma const & lemma, multi_pattern const & mp, expr const & t) { buffer ps; to_buffer(mp, ps); + /* TODO(Leo): use heuristic to select the pattern we will match first */ state init_state = mk_inital_state(ps); - buffer p0_args; - get_app_args(ps[0], p0_args); - ematch_core(lemma, init_state, p0_args, t); + main(lemma, init_state, ps[0], t); } - void ematch(hinst_lemma const & lemma, expr const & t) { + void ematch_term(hinst_lemma const & lemma, expr const & t) { for (multi_pattern const & mp : lemma.m_multi_patterns) { - ematch(lemma, mp, t); + ematch_term(lemma, mp, t); } } - void ematch_all_core(hinst_lemma const & lemma, buffer const & ps, bool filter) { - expr const & p0 = ps[0]; - buffer p0_args; - expr const & f = get_app_args(p0, p0_args); + void ematch_terms_core(hinst_lemma const & lemma, buffer const & ps, bool filter) { + expr const & fn = get_app_fn(ps[0]); unsigned gmt = m_cc.get_gmt(); state init_state = mk_inital_state(ps); - if (rb_expr_set const * s = m_em_state.get_app_map().find(head_index(f))) { + if (rb_expr_set const * s = m_em_state.get_app_map().find(head_index(fn))) { s->for_each([&](expr const & t) { if ((m_cc.is_congr_root(t) || m_cc.in_heterogeneous_eqc(t)) && (!filter || m_cc.get_mt(t) == gmt)) { - ematch_core(lemma, init_state, p0_args, t); + main(lemma, init_state, ps[0], t); } }); } } - void ematch_all(hinst_lemma const & lemma, multi_pattern const & mp, bool filter) { + /* Match internalized terms in m_em_state with the given multipatterns. + If filter is true, then we use the term modification time information + stored in the congruence closure module. Only terms with + modification time (mt) >= the global modification time (gmt) are considered. */ + void ematch_terms(hinst_lemma const & lemma, multi_pattern const & mp, bool filter) { buffer ps; to_buffer(mp, ps); if (filter) { for (unsigned i = 0; i < ps.size(); i++) { std::swap(ps[0], ps[i]); - ematch_all_core(lemma, ps, filter); + ematch_terms_core(lemma, ps, filter); std::swap(ps[0], ps[i]); } } else { - ematch_all_core(lemma, ps, filter); + ematch_terms_core(lemma, ps, filter); } } - void ematch_all(hinst_lemma const & lemma, bool filter) { + /* Match internalized terms in m_em_state with the given lemma. */ + void ematch_terms(hinst_lemma const & lemma, bool filter) { for (multi_pattern const & mp : lemma.m_multi_patterns) { - ematch_all(lemma, mp, filter); + ematch_terms(lemma, mp, filter); } } - void ematch_lemmas(hinst_lemmas const & lemmas, bool filter) { + /* Match internalized terms in m_em_state with the given lemmas. */ + void ematch_using_lemmas(hinst_lemmas const & lemmas, bool filter) { lemmas.for_each([&](hinst_lemma const & lemma) { if (!m_em_state.max_instances_exceeded()) { - ematch_all(lemma, filter); + ematch_terms(lemma, filter); } }); } @@ -556,8 +990,8 @@ struct ematch_fn { void operator()() { if (m_em_state.max_instances_exceeded()) return; - ematch_lemmas(m_em_state.get_new_lemmas(), false); - ematch_lemmas(m_em_state.get_lemmas(), true); + ematch_using_lemmas(m_em_state.get_new_lemmas(), false); + ematch_using_lemmas(m_em_state.get_lemmas(), true); m_em_state.m_lemmas.merge(m_em_state.m_new_lemmas); m_em_state.m_new_lemmas = hinst_lemmas(); m_cc.inc_gmt(); @@ -565,14 +999,17 @@ struct ematch_fn { }; void ematch(type_context & ctx, ematch_state & s, congruence_closure & cc, hinst_lemma const & lemma, expr const & t, buffer & result) { - ematch_fn(ctx, s, cc, result).ematch(lemma, t); + congruence_closure::state_scope scope(cc); + ematch_fn(ctx, s, cc, result).ematch_term(lemma, t); } void ematch(type_context & ctx, ematch_state & s, congruence_closure & cc, hinst_lemma const & lemma, bool filter, buffer & result) { - ematch_fn(ctx, s, cc, result).ematch_all(lemma, filter); + congruence_closure::state_scope scope(cc); + ematch_fn(ctx, s, cc, result).ematch_terms(lemma, filter); } void ematch(type_context & ctx, ematch_state & s, congruence_closure & cc, buffer & result) { + congruence_closure::state_scope scope(cc); ematch_fn(ctx, s, cc, result)(); } diff --git a/src/library/tactic/smt/theory_ac.h b/src/library/tactic/smt/theory_ac.h index 8d42c40b67..a241cdc23f 100644 --- a/src/library/tactic/smt/theory_ac.h +++ b/src/library/tactic/smt/theory_ac.h @@ -95,7 +95,6 @@ private: ac_manager m_ac_manager; buffer m_todo; - optional is_ac(expr const & e); expr convert(expr const & op, expr const & e, buffer & args); bool internalize_var(expr const & e); void insert_erase_R_occ(expr const & arg, expr const & lhs, bool in_lhs, bool is_insert); @@ -122,6 +121,7 @@ public: void internalize(expr const & e, optional const & parent); void add_eq(expr const & e1, expr const & e2); + optional is_ac(expr const & e); format pp_term(formatter const & fmt, expr const & e) const { return m_state.pp_term(fmt, e);