fix(library/equations_compiler/elim_match): constructor transition

This commit is contained in:
Leonardo de Moura 2016-08-21 21:47:36 -07:00
parent a4577901e8
commit 9c55ede671

View file

@ -16,6 +16,7 @@ Author: Leonardo de Moura
#include "library/app_builder.h"
#include "library/tactic/tactic_state.h"
#include "library/tactic/revert_tactic.h"
#include "library/tactic/clear_tactic.h"
#include "library/tactic/cases_tactic.h"
#include "library/tactic/intro_tactic.h"
#include "library/equations_compiler/equations.h"
@ -62,12 +63,13 @@ struct elim_match_fn {
};
struct lemma {
list<expr> m_vars;
expr m_eqn; /* equation (it might be conditional) */
expr m_proof;
local_context m_lctx;
list<expr> m_vars;
expr m_eqn; /* equation (it might be conditional) */
expr m_proof;
lemma() {}
lemma(list<expr> const & vars, expr const & eqn, expr const & proof):
m_vars(vars), m_eqn(eqn), m_proof(proof) {}
lemma(local_context const & lctx, list<expr> const & vars, expr const & eqn, expr const & proof):
m_lctx(lctx), m_vars(vars), m_eqn(eqn), m_proof(proof) {}
};
/** Result for the compilation procedure. */
@ -384,9 +386,9 @@ struct elim_match_fn {
void trace_lemmas(program const & P, char const * header, buffer<lemma> const & lemmas) {
trace_match_detail({
tout() << "[" << m_depth << "] " << header << " lemmas:\n";
auto pp_fn = mk_pp_ctx(P);
for (lemma const & L : lemmas) {
/* Replace function with its name. */
auto pp_fn = mk_pp_ctx(L.m_lctx);
expr tmp_eqn = update_eqn_lhs(L.m_eqn, P.m_nvars,
[&](buffer<expr> const & args) {
return mk_rev_app(mk_constant(P.m_fn_name), args);
@ -402,7 +404,30 @@ struct elim_match_fn {
result compile_skip(program const & P) {
trace_match(tout() << "skip transition\n";);
lean_unreachable();
program new_P;
new_P.m_fn_name = P.m_fn_name;
buffer<name> new_names;
optional<expr> aux_goal = intron(m_env, m_opts, m_mctx, P.m_goal, 1, new_names);
if (!aux_goal) throw_ill_formed_eqns();
lean_assert(new_names.size() == 1);
expr H = m_mctx.get_metavar_decl(*aux_goal)->get_context().get_local_decl(new_names[0])->mk_ref();
new_P.m_goal = *aux_goal; // clear(m_mctx, *aux_goal, H);
new_P.m_nvars = P.m_nvars - 1;
buffer<equation> new_eqns;
for (equation const & eqn : P.m_equations) {
equation new_eqn = eqn;
new_eqn.m_patterns = tail(eqn.m_patterns);
new_eqns.push_back(new_eqn);
}
new_P.m_equations = to_list(new_eqns);
result R = compile_core(new_P);
result new_R;
type_context ctx = mk_type_context(P);
if (m_lemmas) {
// TODO(Leo)
}
return new_R;
}
/* Update the equation left hand side
@ -448,6 +473,7 @@ struct elim_match_fn {
buffer<lemma> new_lemmas;
for (lemma const & L : R.m_lemmas) {
lemma new_L;
new_L.m_lctx = L.m_lctx;
new_L.m_vars = cons(x, L.m_vars);
new_L.m_eqn = update_eqn_for_variable_transition(L.m_eqn, new_P.m_nvars, new_R.m_code, x);
new_L.m_proof = L.m_proof;
@ -474,8 +500,7 @@ struct elim_match_fn {
(nil, L_1 := R_1)
(cons, (cons a b) L_2 := R_2)
(cons, (cons c d) L_3 := R_3)
(nil L_4 := R_4)
*/
(nil L_4 := R_4) */
void distribute_constructor_equations(list<equation> const & eqns, buffer<pair<name, equation>> & R) {
for (equation const & eqn : eqns) {
lean_assert(eqn.m_patterns);
@ -502,8 +527,7 @@ struct elim_match_fn {
The parameter \c field should be interpreted as a bit-mask here.
It says which constructor fields should be used. That is, "some" value means the field
should be considered.
*/
should be considered. */
list<equation> get_equations_for(name const & C, list<optional<name>> const & fields, name_map<name> const & renaming,
local_context const & lctx, buffer<pair<name, equation>> const & eqns) {
buffer<equation> R;
@ -554,20 +578,44 @@ struct elim_match_fn {
}
}
static list<bool> to_bitmask(list<optional<name>> const & ilist) {
return map2<bool>(ilist, [](optional<name> const & ilist) { return static_cast<bool>(ilist); });
}
/* Update an equation left-hand-side of the form
(f a_1 ... a_n b_1 ... b_m)
where n == nfields and n+m == arity, with
where n == number of true entries in mask, and n+m == arity, with
(new_fn (c a_1 ... a_n) b_1 ... b_m) */
expr update_eqn_for_constructor_transition(expr const & eqn, unsigned arity, expr const & new_fn, expr const & c, unsigned nfields) {
return update_eqn_lhs(eqn, arity, [&](buffer<expr> & args) {
lean_assert(args.size() >= nfields);
expr c_app = mk_rev_app(c, nfields, args.end() - nfields);
args.shrink(args.size() - nfields);
args.push_back(c_app);
return mk_rev_app(new_fn, args);
(new_fn (c I_params i ... a_1 ... a_n) b_1 ... b_m)
if there are false entries in mask, we need to infer any missing arguments 'i'. */
expr update_eqn_for_constructor_transition(lemma const & L, list<bool> const & mask,
unsigned arity, expr const & new_fn,
name const & c_name, buffer<expr> const & I_params) {
type_context ctx = mk_type_context(L.m_lctx);
return update_eqn_lhs(L.m_eqn, arity, [&](buffer<expr> & args) {
std::reverse(args.begin(), args.end());
buffer<bool> c_mask;
buffer<expr> c_args;
/* Add I_params */
for (expr const & p : I_params) {
c_mask.push_back(true);
c_args.push_back(p);
}
/* Add constructor fields */
unsigned i = 0;
for (bool b : mask) {
/* Remark: b is false only for indexed families. */
c_mask.push_back(b);
if (b) {
c_args.push_back(args[i]);
i++;
}
}
expr c_app = mk_app(ctx, c_name, c_mask.size(), c_mask.data(), c_args.data());
return mk_app(mk_app(new_fn, c_app), args.size() - i, args.data() + i);
});
}
@ -601,10 +649,11 @@ struct elim_match_fn {
distribute_constructor_equations(P.m_equations, equations_by_constructor);
/* For each (reachable) case, we invoke compile recursively, and we store
- name of the constructor
- number of fields of this constructor
- bitmask for which fields were introduced. The lenght of the bitmask is equal
to the head(ilist). The value is true iff the corresponding element in head(ilist) is not none.
- "arity" of the auxiliary program being used in the recursive call
- result of the compilation for this auxiliary function. */
buffer<std::tuple<name, unsigned, unsigned, result>> result_by_constructor;
buffer<std::tuple<name, list<bool>, unsigned, result>> result_by_constructor;
while (new_goals) {
lean_assert(new_goal_cnames && ilist && rlist);
program new_P;
@ -613,18 +662,18 @@ struct elim_match_fn {
/* Revert constructor fields (which have not been eliminated by dependent pattern matching). */
buffer<expr> to_revert;
to_buffer_local(new_goal, head(ilist), to_revert);
unsigned to_revert_size = to_revert.size();
unsigned nfields = to_revert_size;
expr aux_mvar2 = revert(m_env, m_opts, m_mctx, head(new_goals), to_revert);
unsigned to_revert_size = to_revert.size();
unsigned num_intro_fields = to_revert_size;
expr aux_mvar2 = revert(m_env, m_opts, m_mctx, head(new_goals), to_revert);
lean_assert(to_revert.size() == to_revert_size);
new_P.m_goal = aux_mvar2;
/* The arity of the auxiliary program is the arity of the original program
- 1 (we consumed one argument in this step) and + nfields (we added nfields new arguments). */
new_P.m_nvars = P.m_nvars - 1 + nfields;
- 1 (we consumed one argument in this step) and + number of introduced constructor fields. */
new_P.m_nvars = P.m_nvars - 1 + num_intro_fields;
new_P.m_equations = get_equations_for(head(new_goal_cnames), head(ilist), head(rlist),
get_local_context(aux_mvar2), equations_by_constructor);
result new_R = compile_core(new_P);
result_by_constructor.emplace_back(head(new_goal_cnames), nfields, new_P.m_nvars, new_R);
result_by_constructor.emplace_back(head(new_goal_cnames), to_bitmask(head(ilist)), new_P.m_nvars, new_R);
new_goals = tail(new_goals);
new_goal_cnames = tail(new_goal_cnames);
@ -639,16 +688,16 @@ struct elim_match_fn {
buffer<expr> I_params;
levels I_lvls = get_inductive_levels_and_params(ctx, I, I_params);
buffer<lemma> new_lemmas;
for (std::tuple<name, unsigned, unsigned, result> const & entry : result_by_constructor) {
for (std::tuple<name, list<bool>, unsigned, result> const & entry : result_by_constructor) {
name const & cname = std::get<0>(entry); /* constructor name */
unsigned nfields = std::get<1>(entry);
list<bool> mask = std::get<1>(entry); /* bitmask indicating which constructor fields have been introduced by cases-tactic. */
unsigned arity = std::get<2>(entry);
result const & Rc = std::get<3>(entry);
expr c = mk_app(mk_constant(cname, I_lvls), I_params);
for (lemma const & L : Rc.m_lemmas) {
lemma new_L;
new_L.m_lctx = L.m_lctx;
new_L.m_vars = L.m_vars;
new_L.m_eqn = update_eqn_for_constructor_transition(L.m_eqn, arity, new_R.m_code, c, nfields);
new_L.m_eqn = update_eqn_for_constructor_transition(L, mask, arity, new_R.m_code, cname, I_params);
new_L.m_proof = L.m_proof;
new_lemmas.push_back(new_L);
}
@ -690,7 +739,7 @@ struct elim_match_fn {
type_context ctx = mk_type_context(get_local_context(P));
expr eq = mk_eq(ctx, rhs, rhs);
expr H = mk_eq_refl(ctx, rhs);
R.m_lemmas = to_list(lemma(list<expr>(), eq, H));
R.m_lemmas = to_list(lemma(ctx.lctx(), list<expr>(), eq, H));
}
return R;
}
@ -738,6 +787,17 @@ struct elim_match_fn {
return R;
}
void abstract_eqns_vars(list<lemma> const & Ls, buffer<expr_pair> & R) {
for (lemma const & L : Ls) {
type_context ctx = mk_type_context(L.m_lctx);
buffer<expr> vars;
to_buffer(L.m_vars, vars);
expr e = ctx.mk_pi(vars, L.m_eqn);
expr H = ctx.mk_lambda(vars, L.m_proof);
R.emplace_back(e, H);
}
}
expr operator()(local_context const & lctx, expr const & eqns) {
lean_assert(equations_num_fns(eqns) == 1);
DEBUG_CODE({
@ -748,6 +808,13 @@ struct elim_match_fn {
program P = mk_program(lctx, eqns);
result R = compile(P);
#if 0
if (m_lemmas) {
buffer<expr_pair> Hs;
abstract_eqns_vars(R.m_lemmas, Hs);
}
#endif
lean_unreachable();
}
};