feat(library/tactic/smt/ematch): skeleton for AC support

This commit is contained in:
Leonardo de Moura 2017-01-08 23:53:57 -08:00
parent 0a6a09fb3a
commit 694bb5c7b8
5 changed files with 575 additions and 112 deletions

View file

@ -86,4 +86,16 @@ bool has_idx_metavar(expr const & e) {
});
return found;
}
bool has_idx_expr_metavar(expr const & e) {
if (!has_expr_metavar(e))
return false;
bool found = false;
for_each(e, [&](expr const & e, unsigned) {
if (found || !has_expr_metavar(e)) return false;
if (is_idx_metavar(e)) found = true;
return true;
});
return found;
}
}

View file

@ -29,6 +29,7 @@ unsigned to_meta_idx(expr const & e);
/** \brief Return true iff \c e contains idx metavariables or universe metavariables */
bool has_idx_metavar(expr const & e);
bool has_idx_expr_metavar(expr const & e);
void initialize_idx_metavar();
void finalize_idx_metavar();

View file

@ -265,8 +265,6 @@ private:
void add_eqv_core(expr const & lhs, expr const & rhs, expr const & H, bool heq_proof);
bool check_eqc(expr const & e) const;
expr normalize(expr const & e);
friend ext_congr_lemma_cache_ptr const & get_cache_ptr(congruence_closure const & cc);
public:
congruence_closure(type_context & ctx, state & s, defeq_canonizer::state & dcs,
@ -317,8 +315,23 @@ public:
optional<ext_congr_lemma> mk_ext_congr_lemma(expr const & e) const;
optional<expr> is_ac(expr const & e) {
if (m_state.m_config.m_ac) return m_ac.is_ac(e);
else return none_expr();
}
entry const * get_entry(expr const & e) const { return m_state.m_entries.find(e); }
bool check_invariant() const { return m_state.check_invariant(); }
expr normalize(expr const & e);
class state_scope {
congruence_closure & m_cc;
state m_saved_state;
public:
state_scope(congruence_closure & cc):m_cc(cc), m_saved_state(cc.m_state) {}
~state_scope() { m_cc.m_state = m_saved_state; }
};
};
typedef congruence_closure::state cc_state;

View file

@ -6,7 +6,9 @@ Author: Leonardo de Moura
*/
#include <algorithm>
#include "util/interrupt.h"
#include "util/small_object_allocator.h"
#include "library/trace.h"
#include "library/util.h"
#include "library/constants.h"
#include "library/app_builder.h"
#include "library/fun_info.h"
@ -80,23 +82,402 @@ bool ematch_state::save_instance(expr const & lemma, buffer<expr> const & args)
return save_instance(key);
}
struct ematch_fn {
enum frame_kind { DefEqOnly, EqvOnly, Match, MatchSS /* match subsingleton */, Continue };
typedef std::tuple<frame_kind, expr, expr> entry;
typedef list<entry> state;
typedef list<state> choice;
/* Allocator for ematching constraints. */
MK_THREAD_LOCAL_GET(small_object_allocator, get_emc_allocator, "ematch constraint");
enum class ematch_cnstr_kind { DefEqOnly, EqvOnly, Match, MatchAC, MatchSS /* match subsingleton */, Continue };
class ematch_cnstr;
/** \brief Base class for Ematching constraints.
Remark: these objects are thread local. So, we don't need synchronization. */
struct ematch_cnstr_cell {
unsigned m_rc;
ematch_cnstr_kind m_kind;
void inc_ref() { m_rc++; }
bool dec_ref_core() { lean_assert(m_rc > 0); m_rc--; return m_rc == 0; }
void dec_ref() { if (dec_ref_core()) { dealloc(); } }
void dealloc();
ematch_cnstr_cell(ematch_cnstr_kind k):m_rc(0), m_kind(k) {}
ematch_cnstr_kind kind() const { return m_kind; }
unsigned get_rc() const { return m_rc; }
};
/* Ematching constraint smart pointer */
class ematch_cnstr {
friend class ematch_cnstr_cell;
ematch_cnstr_cell * m_data;
public:
ematch_cnstr():m_data(nullptr) {}
explicit ematch_cnstr(ematch_cnstr_cell * c):m_data(c) { m_data->inc_ref(); }
ematch_cnstr(ematch_cnstr const & o):m_data(o.m_data) { m_data->inc_ref(); }
ematch_cnstr(ematch_cnstr && o):m_data(o.m_data) { o.m_data = nullptr; }
~ematch_cnstr() { if (m_data) m_data->dec_ref(); }
operator ematch_cnstr_cell*() const { return m_data; }
ematch_cnstr & operator=(ematch_cnstr const & s) {
if (s.m_data) s.m_data->inc_ref();
ematch_cnstr_cell * new_data = s.m_data;
if (m_data) m_data->dec_ref();
m_data = new_data;
return *this;
}
ematch_cnstr & operator=(ematch_cnstr && s) {
if (m_data) m_data->dec_ref();
m_data = s.m_data;
s.m_data = nullptr;
return *this;
}
ematch_cnstr_kind kind() const { return m_data->kind(); }
ematch_cnstr_cell * raw() const { return m_data; }
};
struct ematch_eq_cnstr : public ematch_cnstr_cell {
expr m_p;
expr m_t;
ematch_eq_cnstr(ematch_cnstr_kind k, expr const & p, expr const & t):
ematch_cnstr_cell(k), m_p(p), m_t(t) {}
};
struct ematch_ac_cnstr : public ematch_cnstr_cell {
expr m_op;
list<expr> m_p;
list<expr> m_t;
ematch_ac_cnstr(expr const & op, list<expr> const & p, list<expr> const & t):
ematch_cnstr_cell(ematch_cnstr_kind::MatchAC), m_op(op), m_p(p), m_t(t) {}
};
struct ematch_continue : public ematch_cnstr_cell {
expr m_p;
ematch_continue(expr const & p):
ematch_cnstr_cell(ematch_cnstr_kind::Continue), m_p(p) {}
};
inline bool is_eq_cnstr(ematch_cnstr_cell const * c) {
return
c->kind() == ematch_cnstr_kind::Match || c->kind() == ematch_cnstr_kind::MatchSS ||
c->kind() == ematch_cnstr_kind::DefEqOnly || c->kind() == ematch_cnstr_kind::EqvOnly;
}
static bool is_ac_cnstr(ematch_cnstr_cell const * c) { return c->kind() == ematch_cnstr_kind::MatchAC; }
static bool is_continue(ematch_cnstr_cell const * c) { return c->kind() == ematch_cnstr_kind::Continue; }
static ematch_eq_cnstr * to_eq_cnstr(ematch_cnstr_cell * c) { lean_assert(is_eq_cnstr(c)); return static_cast<ematch_eq_cnstr*>(c); }
static ematch_ac_cnstr * to_ac_cnstr(ematch_cnstr_cell * c) { lean_assert(is_ac_cnstr(c)); return static_cast<ematch_ac_cnstr*>(c); }
static ematch_continue * to_continue(ematch_cnstr_cell * c) { lean_assert(is_continue(c)); return static_cast<ematch_continue*>(c); }
void ematch_cnstr_cell::dealloc() {
lean_assert(get_rc() == 0);
if (is_ac_cnstr(this)) {
to_ac_cnstr(this)->~ematch_ac_cnstr();
get_emc_allocator().deallocate(sizeof(ematch_ac_cnstr), this);
} else if (is_continue(this)) {
to_continue(this)->~ematch_continue();
get_emc_allocator().deallocate(sizeof(ematch_continue), this);
} else {
to_eq_cnstr(this)->~ematch_eq_cnstr();
get_emc_allocator().deallocate(sizeof(ematch_eq_cnstr), this);
}
}
static ematch_cnstr mk_eq_cnstr(ematch_cnstr_kind k, expr const & p, expr const & t) {
return ematch_cnstr(new (get_emc_allocator().allocate(sizeof(ematch_eq_cnstr))) ematch_eq_cnstr(k, p, t));
}
static ematch_cnstr mk_match_ac_cnstr(expr const & op, list<expr> const & p, list<expr> const & t) {
return ematch_cnstr(new (get_emc_allocator().allocate(sizeof(ematch_ac_cnstr))) ematch_ac_cnstr(op, p, t));
}
static ematch_cnstr mk_continue(expr const & p) {
return ematch_cnstr(new (get_emc_allocator().allocate(sizeof(ematch_continue))) ematch_continue(p));
}
static ematch_cnstr mk_match_eq_cnstr(expr const & p, expr const & t) { return mk_eq_cnstr(ematch_cnstr_kind::Match, p, t); }
static ematch_cnstr mk_match_ss_cnstr(expr const & p, expr const & t) { return mk_eq_cnstr(ematch_cnstr_kind::MatchSS, p, t); }
static ematch_cnstr mk_eqv_cnstr(expr const & p, expr const & t) { return mk_eq_cnstr(ematch_cnstr_kind::EqvOnly, p, t); }
static ematch_cnstr mk_defeq_cnstr(expr const & p, expr const & t) { return mk_eq_cnstr(ematch_cnstr_kind::DefEqOnly, p, t); }
static expr const & cnstr_p(ematch_cnstr const & c) { return to_eq_cnstr(c)->m_p; }
static expr const & cnstr_t(ematch_cnstr const & c) { return to_eq_cnstr(c)->m_t; }
static expr const & cont_p(ematch_cnstr const & c) { return to_continue(c)->m_p; }
static expr const & ac_op(ematch_cnstr const & c) { return to_ac_cnstr(c)->m_op; }
static list<expr> const & ac_p(ematch_cnstr const & c) { return to_ac_cnstr(c)->m_p; }
static list<expr> const & ac_t(ematch_cnstr const & c) { return to_ac_cnstr(c)->m_t; }
/*
Matching modulo equalities.
This module also supports matching modulo AC.
The procedure is (supposed to be) complete for E-matching and AC-matching.
However, it is currently incomplete for AC-E-matching.
Here are matching problems that are not supported.
Assuming + is an AC operation.
1) Given { a + b = f c }, solve (?x + f ?x) =?= (a + c + b)
It misses the solution ?x := c
2) Given { a = a + a }, solve (?x + ?x + ?x + ?y) =?= (a + b)
It misses the solution ?x := a, ?y := b
The following implementation is based on standard algorithms for E-matching and
AC-matching. The following extensions are supported.
- E-matching modulo heterogeneous equalities.
Casts are automatically introduced.
Moreover, in standard E-matching, a sub-problem such as
?x =?= t
where ?x is unassigned, is solved by assigning ?x := t.
We add the following extension when t is in a heterogeneous equivalence class.
We peek a term t_i in eqc(t) for each different type, and then create
the subproblems:
?x := t_1 \/ ... \/ ?x := t_k
- Uses higher-order pattern matching whenever higher-order sub-patterns
are found. Example: (?f a) =?= (g a a)
- Subsingleton support. For example, suppose (a b : A), and A is a subsingleton.
Then, the following pattern is solved.
(f a ?x) =?= (f b c)
This is useful when we have proofs embedded in terms.
- Equality expansion preprocessing step for AC-matching subproblems.
Given an AC-matching subproblem p =?= ...+t+...
For each term t' headed by + in eqc(t), we generate a new case:
p =?= ...+t'+...
Limitations:
1- A term t will be expanded at most once per AC subproblem.
Example: given {a = a + a}, and constraint (?x + ?x + ?x + ?y) =?= (a + b).
We produce two cases:
?x + ?x + ?x + ?y =?= a + b
\/
?x + ?x + ?x + ?y =?= a + a + b
2- We do not consider subterms of the form (t+s).
Example: give {a + b = f c}, and constraint {?x + f ?x =?= a + c + b},
this procedure will not generate the new case {?x + f ?x =?= f c + c}
by replacing (a + b) with (f c).
*/
struct ematch_fn {
typedef list<ematch_cnstr> state;
type_context & m_ctx;
ematch_state & m_em_state;
congruence_closure & m_cc;
buffer<expr_pair> & m_new_instances;
state m_state;
buffer<choice> m_choice_stack;
buffer<state> m_choice_stack;
ematch_fn(type_context & ctx, ematch_state & ems, congruence_closure & cc, buffer<expr_pair> & new_insts):
m_ctx(ctx), m_em_state(ems), m_cc(cc), m_new_instances(new_insts) {}
expr instantiate_mvars(expr const & e) {
return m_ctx.instantiate_mvars(e);
}
/* Similar to instantiate_mvars, but it makes sure the assignment at m_ctx is not modified by composition.
That is, suppose we have the assignment { ?x := f ?y, ?y := a }, and we instantiate (g ?x).
The result is (g (f a)), but this method prevents the assignment to be modified to
{ ?x := f a, ?y := a }
We need this feature for AC matching, where we want to be able to quickly detect "partially solved"
variables of the form (?x := ?y + s) where s does not contain metavariables. */
expr safe_instantiate_mvars(expr const & e) {
m_ctx.push_scope();
expr r = instantiate_mvars(e);
m_ctx.pop_scope();
return r;
}
bool is_metavar(expr const & e) { return m_ctx.is_mvar(e); }
bool is_meta(expr const & e) { return is_metavar(get_app_fn(e)); }
bool has_expr_metavar(expr const & e) { return has_idx_expr_metavar(e); }
optional<expr> is_ac(expr const & /* e */) {
// TODO(Leo): enable AC matching when it is done
return none_expr();
// return m_cc.is_ac(e);
}
optional<expr> get_binary_op(expr const & e) {
if (is_app(e) && is_app(app_fn(e)))
return some_expr(app_fn(app_fn(e)));
else
return none_expr();
}
expr internalize(expr const & e) {
expr new_e = m_cc.normalize(e);
m_cc.internalize(new_e);
return new_e;
}
bool is_ground_eq(expr const & p, expr const & t) {
lean_assert(!has_expr_metavar(p));
lean_assert(!has_expr_metavar(t));
return m_cc.is_eqv(p, t) || m_ctx.is_def_eq(p, t);
}
/* Return true iff e is a metavariable, and we have an assignment of the
form e := ?m + s, where + is an AC operator, and ?m is another metavariable. */
bool is_partially_solved(expr const & e) {
lean_assert(is_metavar(e));
if (auto v = m_ctx.get_assignment(e)) {
return is_ac(*v) && m_ctx.is_mvar(app_arg(app_fn(*v)));
} else {
return false;
}
}
void flat_ac(expr const & op, expr const & e, buffer<expr> & args) {
if (optional<expr> curr_op = get_binary_op(e)) {
if (m_ctx.is_def_eq(op, *curr_op)) {
flat_ac(op, app_arg(app_fn(e)), args);
flat_ac(op, app_arg(e), args);
return;
}
}
args.push_back(e);
}
/* Cancel ground terms that occur in p_args and t_args.
Example:
Given
[?x, 0, ?y] [a, b, 0, c],
the result is:
[?x, ?y] [a, b, c]
*/
void ac_cancel_terms(buffer<expr> & p_args, buffer<expr> & t_args) {
unsigned j = 0;
for (unsigned i = 0; i < p_args.size(); i++) {
if (has_expr_metavar(p_args[i])) {
p_args[j] = p_args[i];
j++;
} else {
expr p = internalize(p_args[i]);
unsigned k = 0;
for (; k < t_args.size(); k++) {
if (is_ground_eq(p, t_args[k])) {
break;
}
}
if (k == t_args.size()) {
p_args[j] = p;
j++;
} else {
// cancelled
t_args.erase(k);
}
}
}
p_args.shrink(j);
}
expr mk_ac_term(expr const & op, buffer<expr> const & args) {
lean_assert(!args.empty());
expr r = args.back();
unsigned i = args.size() - 1;
while (i > 0) {
--i;
r = mk_app(op, args[i], r);
}
return r;
}
expr mk_ac_term(expr const & op, list<expr> const & args) {
buffer<expr> b;
to_buffer(args, b);
return mk_ac_term(op, b);
}
void display_ac_cnstr(io_state_stream const & out, ematch_cnstr const & c) {
expr p = mk_ac_term(ac_op(c), ac_p(c));
expr t = mk_ac_term(ac_op(c), ac_t(c));
auto fmt = out.get_formatter();
format r = group(fmt(p) + line() + format("=?=") + line() + fmt(t));
out << r;
}
void process_new_ac_cnstr(state const & s, expr const & p, expr const & t, buffer<state> & new_states) {
optional<expr> op = is_ac(t);
lean_assert(op);
buffer<expr> p_args, t_args;
flat_ac(*op, p, p_args);
flat_ac(*op, t, t_args);
lean_assert(t_args.size() >= 2);
if (p_args.empty()) {
/* This can happen if we fail to unify the operator in p with the one in t. */
return;
}
lean_assert(p_args.size() >= 2);
ac_cancel_terms(p_args, t_args);
if (p_args.size() == 1 && t_args.size() == 1) {
new_states.push_back(cons(mk_match_eq_cnstr(p_args[0], t_args[0]), s));
return;
}
list<expr> ps = to_list(p_args);
buffer<expr> new_t_args;
/* Create a family of AC-matching constraints by replacing t-arguments
with op-applications that are in the same equivalence class.
Example: given (a = b + c) (d = e + f) and t is of the form (a + d).
expand, will add the following AC constraints
p =?= a + d
p =?= a + e + f
p =?= b + c + d
p =?= b + c + e + f
To avoid non termination, we unfold a t_arg at most once.
Here is an example that would produce non-termination if
we did not use unfolded.
Given (a = a + a) and t is of the form (a + d).
We would be able to produce
p =?= a + d
p =?= a + a + d
...
p =?= a + ... + a + d
...
*/
std::function<void(unsigned, rb_expr_tree const &)>
expand = [&](unsigned i, rb_expr_tree const & unfolded) {
check_system("ematching");
if (i == t_args.size()) {
ematch_cnstr c = mk_match_ac_cnstr(*op, ps, to_list(new_t_args));
lean_trace(name({"debug", "ematch"}), tout() << "new ac constraint: "; display_ac_cnstr(tout(), c); tout() << "\n";);
new_states.push_back(cons(c, s));
} else {
expr const & t_arg = t_args[i];
new_t_args.push_back(t_arg);
expand(i+1, unfolded);
new_t_args.pop_back();
/* search for op-applications in eqc(t_arg) */
rb_expr_tree new_unfolded = unfolded;
bool first = true;
expr it = t_arg;
do {
if (auto op2 = is_ac(it)) {
if (*op == *op2) {
unsigned sz = t_args.size();
flat_ac(*op, it, t_args);
if (first) {
new_unfolded.insert(t_arg);
first = false;
}
expand(i+1, new_unfolded);
t_args.shrink(sz);
}
}
it = m_cc.get_next(it);
} while (it != t_arg);
}
};
expand(0, rb_expr_tree());
}
void push_states(buffer<state> & new_states) {
if (new_states.size() == 1) {
lean_trace(name({"debug", "ematch"}), tout() << "(only one match)\n";);
@ -105,19 +486,45 @@ struct ematch_fn {
lean_trace(name({"debug", "ematch"}), tout() << "# matches: " << new_states.size() << "\n";);
m_state = new_states.back();
new_states.pop_back();
choice c = to_list(new_states);
m_choice_stack.push_back(c);
m_ctx.push_scope();
m_choice_stack.append(new_states);
for (unsigned i = 0; i < new_states.size(); i++)
m_ctx.push_scope();
}
}
bool ac_merge_clash_p(expr const & p, expr const & t) {
lean_assert(is_metavar(p) && is_partially_solved(p));
tout() << "ac_merge_clash_p: " << p << " =?= " << t << "\n";
// TODO(Leo):
lean_unreachable();
}
bool is_ac_eqv(expr const & p, expr const & t) {
lean_assert(is_ac(t));
if (is_metavar(p) && is_partially_solved(p)) {
return ac_merge_clash_p(p, t);
} else {
/* When AC support is enabled, metavariables may be assigned to terms
that have not been internalized. */
expr new_p = safe_instantiate_mvars(p);
if (!has_expr_metavar(new_p)) {
new_p = internalize(new_p);
return is_ground_eq(new_p, t);
} else {
return m_ctx.is_def_eq(new_p, t);
}
}
}
bool is_eqv(expr const & p, expr const & t) {
if (!has_expr_metavar(p)) {
return m_cc.is_eqv(p, t) || m_ctx.is_def_eq(p, t);
if (is_ac(t)) {
return is_ac_eqv(p, t);
} else if (!has_expr_metavar(p)) {
return is_ground_eq(p, t);
} else if (is_meta(p)) {
expr const & m = get_app_fn(p);
if (!m_ctx.is_assigned(m)) {
expr p_type = m_ctx.instantiate_mvars(m_ctx.infer(p));
expr p_type = safe_instantiate_mvars(m_ctx.infer(p));
expr t_type = m_ctx.infer(t);
if (m_ctx.is_def_eq(p_type, t_type)) {
/* Types are definitionally equal. So, we just assign */
@ -142,8 +549,8 @@ struct ematch_fn {
Important: we must process arguments from left to right. Otherwise, the "trick"
above will not work.
*/
m_cc.internalize(p_type);
m_cc.internalize(t_type);
p_type = internalize(p_type);
t_type = internalize(t_type);
if (auto H = m_cc.get_eq_proof(t_type, p_type)) {
expr cast_H_t = mk_cast(m_ctx, *H, t);
return m_ctx.is_def_eq(p, cast_H_t);
@ -156,12 +563,15 @@ struct ematch_fn {
using cc since they contain metavariables */
return false;
}
} else if (is_metavar(p) && is_partially_solved(p)) {
return ac_merge_clash_p(p, t);
} else {
expr new_p = m_ctx.instantiate_mvars(p);
if (!has_expr_metavar(new_p))
return m_cc.is_eqv(new_p, t) || m_ctx.is_def_eq(new_p, t);
else
expr new_p = safe_instantiate_mvars(p);
if (!has_expr_metavar(new_p)) {
return is_ground_eq(new_p, t);
} else {
return m_ctx.is_def_eq(new_p, t);
}
}
} else {
return m_ctx.is_def_eq(p, t);
@ -181,7 +591,7 @@ struct ematch_fn {
expr it_type = m_ctx.infer(it);
if (!types_seen.find(it_type)) {
types_seen.insert(it_type);
new_states.emplace_back(cons(entry(EqvOnly, p, it), m_state));
new_states.push_back(cons(mk_eqv_cnstr(p, it), m_state));
}
it = m_cc.get_next(it);
} while (it != t);
@ -205,19 +615,22 @@ struct ematch_fn {
list<ss_param_info> sinfo = get_subsingleton_info(m_ctx, fn, t_args.size());
list<param_info> const * it1 = &finfo.get_params_info();
list<ss_param_info> const *it2 = &sinfo;
buffer<entry> new_entries;
buffer<ematch_cnstr> new_cnstrs;
for (unsigned i = 0; i < t_args.size(); i++) {
if (*it1 && head(*it1).is_inst_implicit()) {
new_entries.emplace_back(DefEqOnly, p_args[i], t_args[i]);
new_cnstrs.push_back(mk_defeq_cnstr(p_args[i], t_args[i]));
lean_assert(new_cnstrs.back().kind() == ematch_cnstr_kind::DefEqOnly);
} else if (*it2 && head(*it2).is_subsingleton()) {
new_entries.emplace_back(MatchSS, p_args[i], t_args[i]);
new_cnstrs.push_back(mk_match_ss_cnstr(p_args[i], t_args[i]));
lean_assert(new_cnstrs.back().kind() == ematch_cnstr_kind::MatchSS);
} else {
new_entries.emplace_back(Match, p_args[i], t_args[i]);
new_cnstrs.push_back(mk_match_eq_cnstr(p_args[i], t_args[i]));
lean_assert(new_cnstrs.back().kind() == ematch_cnstr_kind::Match);
}
if (*it1) it1 = &tail(*it1);
if (*it2) it2 = &tail(*it2);
}
s = to_list(new_entries.begin(), new_entries.end(), s);
s = to_list(new_cnstrs.begin(), new_cnstrs.end(), s);
return true;
} else {
return false;
@ -226,8 +639,8 @@ struct ematch_fn {
bool process_match(expr const & p, expr const & t) {
lean_trace(name({"debug", "ematch"}),
expr new_p = m_ctx.instantiate_mvars(p);
expr new_p_type = m_ctx.instantiate_mvars(m_ctx.infer(p));
expr new_p = safe_instantiate_mvars(p);
expr new_p_type = safe_instantiate_mvars(m_ctx.infer(p));
expr t_type = m_ctx.infer(t);
tout() << "try process_match: " << p << " ::= " << new_p << " : " << new_p_type << " <=?=> "
<< t << " : " << t_type << "\n";);
@ -264,7 +677,9 @@ struct ematch_fn {
buffer<state> new_states;
for (expr const & c : candidates) {
state new_state = m_state;
if (match_args(new_state, p_args, c)) {
if (is_ac(c)) {
process_new_ac_cnstr(new_state, p, t, new_states);
} else if (match_args(new_state, p_args, c)) {
lean_trace(name({"debug", "ematch"}), tout() << "match: " << c << "\n";);
new_states.push_back(new_state);
}
@ -288,15 +703,8 @@ struct ematch_fn {
});
if (new_states.empty()) {
return false;
} else if (new_states.size() == 1) {
m_state = new_states[0];
return true;
} else {
m_state = new_states.back();
new_states.pop_back();
choice c = to_list(new_states);
m_choice_stack.push_back(c);
m_ctx.push_scope();
push_states(new_states);
return true;
}
} else {
@ -308,19 +716,18 @@ struct ematch_fn {
typeof(p) and typeof(t) are subsingletons */
bool process_matchss(expr const & p, expr const & t) {
lean_trace(name({"debug", "ematch"}),
expr new_p = m_ctx.instantiate_mvars(p);
expr new_p_type = m_ctx.instantiate_mvars(m_ctx.infer(p));
expr new_p = safe_instantiate_mvars(p);
expr new_p_type = safe_instantiate_mvars(m_ctx.infer(p));
expr t_type = m_ctx.infer(t);
tout() << "process_matchss: " << p << " ::= " << new_p << " : " << new_p_type << " <=?=> "
<< t << " : " << t_type << "\n";);
if (!is_metavar(p)) {
/* If p is not a metavariable we simply ignore it.
We should improve this case in the future. */
lean_trace(name({"debug", "ematch"}), tout() << "(p not a metavar)\n";);
return true;
}
expr p_type = m_ctx.instantiate_mvars(m_ctx.infer(p));
expr p_type = safe_instantiate_mvars(m_ctx.infer(p));
expr t_type = m_ctx.infer(t);
if (m_ctx.is_def_eq(p_type, t_type)) {
bool success = m_ctx.is_def_eq(p, t);
@ -329,7 +736,7 @@ struct ematch_fn {
return success;
} else {
/* Check if the types are provably equal, and cast t */
m_cc.internalize(p_type);
p_type = internalize(p_type);
if (auto H = m_cc.get_eq_proof(t_type, p_type)) {
expr cast_H_t = mk_cast(m_ctx, *H, t);
bool success = m_ctx.is_def_eq(p, cast_H_t);
@ -342,43 +749,61 @@ struct ematch_fn {
return false;
}
bool process_defeq_only(ematch_cnstr const & c) {
expr const & p = cnstr_p(c);
expr const & t = cnstr_t(c);
bool success = m_ctx.is_def_eq(p, t);
lean_trace(name({"debug", "ematch"}),
expr new_p = safe_instantiate_mvars(p);
expr new_p_type = safe_instantiate_mvars(m_ctx.infer(p));
expr t_type = m_ctx.infer(t);
tout() << "must be def-eq: " << new_p << " : " << new_p_type
<< " =?= " << t << " : " << t_type
<< " ... " << (success ? "succeeded" : "failed") << "\n";);
return success;
}
bool process_eqv_only(ematch_cnstr const & c) {
expr const & p = cnstr_p(c);
expr const & t = cnstr_t(c);
bool success = is_eqv(p, t);
lean_trace(name({"debug", "ematch"}),
expr new_p = safe_instantiate_mvars(p);
expr new_p_type = safe_instantiate_mvars(m_ctx.infer(p));
expr t_type = m_ctx.infer(t);
tout() << "must be eqv: " << new_p << " : " << new_p_type << " =?= "
<< t << " : " << t_type << " ... " << (success ? "succeeded" : "failed") << "\n";);
return success;
}
bool process_match_ac(ematch_cnstr const & /* c */) {
// TODO(Leo)
lean_unreachable();
}
bool is_done() const {
return is_nil(m_state);
}
bool process_next() {
lean_assert(!is_done());
frame_kind kind; expr p, t;
std::tie(kind, p, t) = head(m_state);
m_state = tail(m_state);
/* TODO(Leo): select easy constraint first */
ematch_cnstr c = head(m_state);
m_state = tail(m_state);
bool success;
switch (kind) {
case DefEqOnly:
success = m_ctx.is_def_eq(p, t);
lean_trace(name({"debug", "ematch"}),
expr new_p = m_ctx.instantiate_mvars(p);
expr new_p_type = m_ctx.instantiate_mvars(m_ctx.infer(p));
expr t_type = m_ctx.infer(t);
tout() << "must be def-eq: " << new_p << " : " << new_p_type
<< " =?= " << t << " : " << t_type
<< " ... " << (success ? "succeeded" : "failed") << "\n";);
return success;
case Match:
return process_match(p, t);
case EqvOnly:
success = is_eqv(p, t);
lean_trace(name({"debug", "ematch"}),
expr new_p = m_ctx.instantiate_mvars(p);
expr new_p_type = m_ctx.instantiate_mvars(m_ctx.infer(p));
expr t_type = m_ctx.infer(t);
tout() << "must be eqv: " << new_p << " : " << new_p_type << " =?= "
<< t << " : " << t_type << " ... " << (success ? "succeeded" : "failed") << "\n";);
return success;
case MatchSS:
return process_matchss(p, t);
case Continue:
return process_continue(p);
switch (c.kind()) {
case ematch_cnstr_kind::DefEqOnly:
return process_defeq_only(c);
case ematch_cnstr_kind::Match:
return process_match(cnstr_p(c), cnstr_t(c));
case ematch_cnstr_kind::EqvOnly:
return process_eqv_only(c);
case ematch_cnstr_kind::MatchSS:
return process_matchss(cnstr_p(c), cnstr_t(c));
case ematch_cnstr_kind::Continue:
return process_continue(cont_p(c));
case ematch_cnstr_kind::MatchAC:
return process_match_ac(c);
}
lean_unreachable();
}
@ -388,12 +813,8 @@ struct ematch_fn {
if (m_choice_stack.empty())
return false;
m_ctx.pop_scope();
lean_assert(m_choice_stack.back());
m_state = head(m_choice_stack.back());
m_ctx.push_scope();
m_choice_stack.back() = tail(m_choice_stack.back());
if (!m_choice_stack.back())
m_choice_stack.pop_back();
m_state = m_choice_stack.back();
m_choice_stack.pop_back();
return true;
}
@ -427,7 +848,7 @@ struct ematch_fn {
}
for (expr & lemma_arg : lemma_args) {
lemma_arg = m_ctx.instantiate_mvars(lemma_arg);
lemma_arg = instantiate_mvars(lemma_arg);
if (has_idx_metavar(lemma_arg)) {
lean_trace(name({"debug", "ematch"}),
tout() << "instantiation failure [" << lemma.m_id << "], " <<
@ -440,13 +861,16 @@ struct ematch_fn {
return; // already added this instance
}
expr new_inst = m_ctx.instantiate_mvars(lemma.m_prop);
if (has_idx_metavar(new_inst))
expr new_inst = instantiate_mvars(lemma.m_prop);
if (has_idx_metavar(new_inst)) {
lean_trace(name({"debug", "ematch"}),
tout() << "new instance contains unassigned metavariables\n" << new_inst << "\n";);
return; // result contains temporary metavariables
}
lean_trace("ematch",
tout() << "instance [" << lemma.m_id << "]: " << new_inst << "\n";);
expr new_proof = m_ctx.instantiate_mvars(lemma.m_proof);
expr new_proof = instantiate_mvars(lemma.m_proof);
m_new_instances.emplace_back(new_inst, new_proof);
}
@ -477,78 +901,88 @@ struct ematch_fn {
unsigned i = ps.size();
while (i > 1) {
--i;
s = cons(entry(Continue, ps[i], expr()), s);
s = cons(mk_continue(ps[i]), s);
}
return s;
}
void ematch_core(hinst_lemma const & lemma, state const & s, buffer<expr> const & p0_args, expr const & t) {
lean_trace(name({"debug", "ematch"}),
tout() << "ematch " << lemma.m_id << " [using] " << t << "\n";);
/* Ematch p =?= t with initial state init. p is the pattern, and t is a term.
The inital state init is used for multipatterns.
The given lemma is instantiated for each solution found.
The new instances are stored at m_new_instances. */
void main(hinst_lemma const & lemma, state const & init, expr const & p, expr const & t) {
type_context::tmp_mode_scope scope(m_ctx, lemma.m_num_uvars, lemma.m_num_mvars);
lean_assert(!has_idx_metavar(t));
clear_choice_stack();
m_state = s;
if (!match_args(m_state, p0_args, t))
m_state = init;
buffer<expr> p_args;
expr const & fn = get_app_args(p, p_args);
if (!m_ctx.is_def_eq(fn, get_app_fn(t)))
return;
if (!match_args(m_state, p_args, t))
return;
search(lemma);
}
void ematch(hinst_lemma const & lemma, multi_pattern const & mp, expr const & t) {
void ematch_term(hinst_lemma const & lemma, multi_pattern const & mp, expr const & t) {
buffer<expr> ps;
to_buffer(mp, ps);
/* TODO(Leo): use heuristic to select the pattern we will match first */
state init_state = mk_inital_state(ps);
buffer<expr> p0_args;
get_app_args(ps[0], p0_args);
ematch_core(lemma, init_state, p0_args, t);
main(lemma, init_state, ps[0], t);
}
void ematch(hinst_lemma const & lemma, expr const & t) {
void ematch_term(hinst_lemma const & lemma, expr const & t) {
for (multi_pattern const & mp : lemma.m_multi_patterns) {
ematch(lemma, mp, t);
ematch_term(lemma, mp, t);
}
}
void ematch_all_core(hinst_lemma const & lemma, buffer<expr> const & ps, bool filter) {
expr const & p0 = ps[0];
buffer<expr> p0_args;
expr const & f = get_app_args(p0, p0_args);
void ematch_terms_core(hinst_lemma const & lemma, buffer<expr> const & ps, bool filter) {
expr const & fn = get_app_fn(ps[0]);
unsigned gmt = m_cc.get_gmt();
state init_state = mk_inital_state(ps);
if (rb_expr_set const * s = m_em_state.get_app_map().find(head_index(f))) {
if (rb_expr_set const * s = m_em_state.get_app_map().find(head_index(fn))) {
s->for_each([&](expr const & t) {
if ((m_cc.is_congr_root(t) || m_cc.in_heterogeneous_eqc(t)) &&
(!filter || m_cc.get_mt(t) == gmt)) {
ematch_core(lemma, init_state, p0_args, t);
main(lemma, init_state, ps[0], t);
}
});
}
}
void ematch_all(hinst_lemma const & lemma, multi_pattern const & mp, bool filter) {
/* Match internalized terms in m_em_state with the given multipatterns.
If filter is true, then we use the term modification time information
stored in the congruence closure module. Only terms with
modification time (mt) >= the global modification time (gmt) are considered. */
void ematch_terms(hinst_lemma const & lemma, multi_pattern const & mp, bool filter) {
buffer<expr> ps;
to_buffer(mp, ps);
if (filter) {
for (unsigned i = 0; i < ps.size(); i++) {
std::swap(ps[0], ps[i]);
ematch_all_core(lemma, ps, filter);
ematch_terms_core(lemma, ps, filter);
std::swap(ps[0], ps[i]);
}
} else {
ematch_all_core(lemma, ps, filter);
ematch_terms_core(lemma, ps, filter);
}
}
void ematch_all(hinst_lemma const & lemma, bool filter) {
/* Match internalized terms in m_em_state with the given lemma. */
void ematch_terms(hinst_lemma const & lemma, bool filter) {
for (multi_pattern const & mp : lemma.m_multi_patterns) {
ematch_all(lemma, mp, filter);
ematch_terms(lemma, mp, filter);
}
}
void ematch_lemmas(hinst_lemmas const & lemmas, bool filter) {
/* Match internalized terms in m_em_state with the given lemmas. */
void ematch_using_lemmas(hinst_lemmas const & lemmas, bool filter) {
lemmas.for_each([&](hinst_lemma const & lemma) {
if (!m_em_state.max_instances_exceeded()) {
ematch_all(lemma, filter);
ematch_terms(lemma, filter);
}
});
}
@ -556,8 +990,8 @@ struct ematch_fn {
void operator()() {
if (m_em_state.max_instances_exceeded())
return;
ematch_lemmas(m_em_state.get_new_lemmas(), false);
ematch_lemmas(m_em_state.get_lemmas(), true);
ematch_using_lemmas(m_em_state.get_new_lemmas(), false);
ematch_using_lemmas(m_em_state.get_lemmas(), true);
m_em_state.m_lemmas.merge(m_em_state.m_new_lemmas);
m_em_state.m_new_lemmas = hinst_lemmas();
m_cc.inc_gmt();
@ -565,14 +999,17 @@ struct ematch_fn {
};
void ematch(type_context & ctx, ematch_state & s, congruence_closure & cc, hinst_lemma const & lemma, expr const & t, buffer<expr_pair> & result) {
ematch_fn(ctx, s, cc, result).ematch(lemma, t);
congruence_closure::state_scope scope(cc);
ematch_fn(ctx, s, cc, result).ematch_term(lemma, t);
}
void ematch(type_context & ctx, ematch_state & s, congruence_closure & cc, hinst_lemma const & lemma, bool filter, buffer<expr_pair> & result) {
ematch_fn(ctx, s, cc, result).ematch_all(lemma, filter);
congruence_closure::state_scope scope(cc);
ematch_fn(ctx, s, cc, result).ematch_terms(lemma, filter);
}
void ematch(type_context & ctx, ematch_state & s, congruence_closure & cc, buffer<expr_pair> & result) {
congruence_closure::state_scope scope(cc);
ematch_fn(ctx, s, cc, result)();
}

View file

@ -95,7 +95,6 @@ private:
ac_manager m_ac_manager;
buffer<expr_triple> m_todo;
optional<expr> is_ac(expr const & e);
expr convert(expr const & op, expr const & e, buffer<expr> & args);
bool internalize_var(expr const & e);
void insert_erase_R_occ(expr const & arg, expr const & lhs, bool in_lhs, bool is_insert);
@ -122,6 +121,7 @@ public:
void internalize(expr const & e, optional<expr> const & parent);
void add_eq(expr const & e1, expr const & e2);
optional<expr> is_ac(expr const & e);
format pp_term(formatter const & fmt, expr const & e) const {
return m_state.pp_term(fmt, e);