fix(library/equations_compiler/elim_match): constructor transition
This commit is contained in:
parent
a4577901e8
commit
9c55ede671
1 changed files with 100 additions and 33 deletions
|
|
@ -16,6 +16,7 @@ Author: Leonardo de Moura
|
|||
#include "library/app_builder.h"
|
||||
#include "library/tactic/tactic_state.h"
|
||||
#include "library/tactic/revert_tactic.h"
|
||||
#include "library/tactic/clear_tactic.h"
|
||||
#include "library/tactic/cases_tactic.h"
|
||||
#include "library/tactic/intro_tactic.h"
|
||||
#include "library/equations_compiler/equations.h"
|
||||
|
|
@ -62,12 +63,13 @@ struct elim_match_fn {
|
|||
};
|
||||
|
||||
struct lemma {
|
||||
list<expr> m_vars;
|
||||
expr m_eqn; /* equation (it might be conditional) */
|
||||
expr m_proof;
|
||||
local_context m_lctx;
|
||||
list<expr> m_vars;
|
||||
expr m_eqn; /* equation (it might be conditional) */
|
||||
expr m_proof;
|
||||
lemma() {}
|
||||
lemma(list<expr> const & vars, expr const & eqn, expr const & proof):
|
||||
m_vars(vars), m_eqn(eqn), m_proof(proof) {}
|
||||
lemma(local_context const & lctx, list<expr> const & vars, expr const & eqn, expr const & proof):
|
||||
m_lctx(lctx), m_vars(vars), m_eqn(eqn), m_proof(proof) {}
|
||||
};
|
||||
|
||||
/** Result for the compilation procedure. */
|
||||
|
|
@ -384,9 +386,9 @@ struct elim_match_fn {
|
|||
void trace_lemmas(program const & P, char const * header, buffer<lemma> const & lemmas) {
|
||||
trace_match_detail({
|
||||
tout() << "[" << m_depth << "] " << header << " lemmas:\n";
|
||||
auto pp_fn = mk_pp_ctx(P);
|
||||
for (lemma const & L : lemmas) {
|
||||
/* Replace function with its name. */
|
||||
auto pp_fn = mk_pp_ctx(L.m_lctx);
|
||||
expr tmp_eqn = update_eqn_lhs(L.m_eqn, P.m_nvars,
|
||||
[&](buffer<expr> const & args) {
|
||||
return mk_rev_app(mk_constant(P.m_fn_name), args);
|
||||
|
|
@ -402,7 +404,30 @@ struct elim_match_fn {
|
|||
|
||||
result compile_skip(program const & P) {
|
||||
trace_match(tout() << "skip transition\n";);
|
||||
lean_unreachable();
|
||||
program new_P;
|
||||
new_P.m_fn_name = P.m_fn_name;
|
||||
buffer<name> new_names;
|
||||
optional<expr> aux_goal = intron(m_env, m_opts, m_mctx, P.m_goal, 1, new_names);
|
||||
if (!aux_goal) throw_ill_formed_eqns();
|
||||
lean_assert(new_names.size() == 1);
|
||||
expr H = m_mctx.get_metavar_decl(*aux_goal)->get_context().get_local_decl(new_names[0])->mk_ref();
|
||||
new_P.m_goal = *aux_goal; // clear(m_mctx, *aux_goal, H);
|
||||
new_P.m_nvars = P.m_nvars - 1;
|
||||
buffer<equation> new_eqns;
|
||||
for (equation const & eqn : P.m_equations) {
|
||||
equation new_eqn = eqn;
|
||||
new_eqn.m_patterns = tail(eqn.m_patterns);
|
||||
new_eqns.push_back(new_eqn);
|
||||
}
|
||||
new_P.m_equations = to_list(new_eqns);
|
||||
result R = compile_core(new_P);
|
||||
result new_R;
|
||||
type_context ctx = mk_type_context(P);
|
||||
if (m_lemmas) {
|
||||
// TODO(Leo)
|
||||
|
||||
}
|
||||
return new_R;
|
||||
}
|
||||
|
||||
/* Update the equation left hand side
|
||||
|
|
@ -448,6 +473,7 @@ struct elim_match_fn {
|
|||
buffer<lemma> new_lemmas;
|
||||
for (lemma const & L : R.m_lemmas) {
|
||||
lemma new_L;
|
||||
new_L.m_lctx = L.m_lctx;
|
||||
new_L.m_vars = cons(x, L.m_vars);
|
||||
new_L.m_eqn = update_eqn_for_variable_transition(L.m_eqn, new_P.m_nvars, new_R.m_code, x);
|
||||
new_L.m_proof = L.m_proof;
|
||||
|
|
@ -474,8 +500,7 @@ struct elim_match_fn {
|
|||
(nil, L_1 := R_1)
|
||||
(cons, (cons a b) L_2 := R_2)
|
||||
(cons, (cons c d) L_3 := R_3)
|
||||
(nil L_4 := R_4)
|
||||
*/
|
||||
(nil L_4 := R_4) */
|
||||
void distribute_constructor_equations(list<equation> const & eqns, buffer<pair<name, equation>> & R) {
|
||||
for (equation const & eqn : eqns) {
|
||||
lean_assert(eqn.m_patterns);
|
||||
|
|
@ -502,8 +527,7 @@ struct elim_match_fn {
|
|||
|
||||
The parameter \c field should be interpreted as a bit-mask here.
|
||||
It says which constructor fields should be used. That is, "some" value means the field
|
||||
should be considered.
|
||||
*/
|
||||
should be considered. */
|
||||
list<equation> get_equations_for(name const & C, list<optional<name>> const & fields, name_map<name> const & renaming,
|
||||
local_context const & lctx, buffer<pair<name, equation>> const & eqns) {
|
||||
buffer<equation> R;
|
||||
|
|
@ -554,20 +578,44 @@ struct elim_match_fn {
|
|||
}
|
||||
}
|
||||
|
||||
static list<bool> to_bitmask(list<optional<name>> const & ilist) {
|
||||
return map2<bool>(ilist, [](optional<name> const & ilist) { return static_cast<bool>(ilist); });
|
||||
}
|
||||
|
||||
/* Update an equation left-hand-side of the form
|
||||
|
||||
(f a_1 ... a_n b_1 ... b_m)
|
||||
|
||||
where n == nfields and n+m == arity, with
|
||||
where n == number of true entries in mask, and n+m == arity, with
|
||||
|
||||
(new_fn (c a_1 ... a_n) b_1 ... b_m) */
|
||||
expr update_eqn_for_constructor_transition(expr const & eqn, unsigned arity, expr const & new_fn, expr const & c, unsigned nfields) {
|
||||
return update_eqn_lhs(eqn, arity, [&](buffer<expr> & args) {
|
||||
lean_assert(args.size() >= nfields);
|
||||
expr c_app = mk_rev_app(c, nfields, args.end() - nfields);
|
||||
args.shrink(args.size() - nfields);
|
||||
args.push_back(c_app);
|
||||
return mk_rev_app(new_fn, args);
|
||||
(new_fn (c I_params i ... a_1 ... a_n) b_1 ... b_m)
|
||||
|
||||
if there are false entries in mask, we need to infer any missing arguments 'i'. */
|
||||
expr update_eqn_for_constructor_transition(lemma const & L, list<bool> const & mask,
|
||||
unsigned arity, expr const & new_fn,
|
||||
name const & c_name, buffer<expr> const & I_params) {
|
||||
type_context ctx = mk_type_context(L.m_lctx);
|
||||
return update_eqn_lhs(L.m_eqn, arity, [&](buffer<expr> & args) {
|
||||
std::reverse(args.begin(), args.end());
|
||||
buffer<bool> c_mask;
|
||||
buffer<expr> c_args;
|
||||
/* Add I_params */
|
||||
for (expr const & p : I_params) {
|
||||
c_mask.push_back(true);
|
||||
c_args.push_back(p);
|
||||
}
|
||||
/* Add constructor fields */
|
||||
unsigned i = 0;
|
||||
for (bool b : mask) {
|
||||
/* Remark: b is false only for indexed families. */
|
||||
c_mask.push_back(b);
|
||||
if (b) {
|
||||
c_args.push_back(args[i]);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
expr c_app = mk_app(ctx, c_name, c_mask.size(), c_mask.data(), c_args.data());
|
||||
return mk_app(mk_app(new_fn, c_app), args.size() - i, args.data() + i);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -601,10 +649,11 @@ struct elim_match_fn {
|
|||
distribute_constructor_equations(P.m_equations, equations_by_constructor);
|
||||
/* For each (reachable) case, we invoke compile recursively, and we store
|
||||
- name of the constructor
|
||||
- number of fields of this constructor
|
||||
- bitmask for which fields were introduced. The lenght of the bitmask is equal
|
||||
to the head(ilist). The value is true iff the corresponding element in head(ilist) is not none.
|
||||
- "arity" of the auxiliary program being used in the recursive call
|
||||
- result of the compilation for this auxiliary function. */
|
||||
buffer<std::tuple<name, unsigned, unsigned, result>> result_by_constructor;
|
||||
buffer<std::tuple<name, list<bool>, unsigned, result>> result_by_constructor;
|
||||
while (new_goals) {
|
||||
lean_assert(new_goal_cnames && ilist && rlist);
|
||||
program new_P;
|
||||
|
|
@ -613,18 +662,18 @@ struct elim_match_fn {
|
|||
/* Revert constructor fields (which have not been eliminated by dependent pattern matching). */
|
||||
buffer<expr> to_revert;
|
||||
to_buffer_local(new_goal, head(ilist), to_revert);
|
||||
unsigned to_revert_size = to_revert.size();
|
||||
unsigned nfields = to_revert_size;
|
||||
expr aux_mvar2 = revert(m_env, m_opts, m_mctx, head(new_goals), to_revert);
|
||||
unsigned to_revert_size = to_revert.size();
|
||||
unsigned num_intro_fields = to_revert_size;
|
||||
expr aux_mvar2 = revert(m_env, m_opts, m_mctx, head(new_goals), to_revert);
|
||||
lean_assert(to_revert.size() == to_revert_size);
|
||||
new_P.m_goal = aux_mvar2;
|
||||
/* The arity of the auxiliary program is the arity of the original program
|
||||
- 1 (we consumed one argument in this step) and + nfields (we added nfields new arguments). */
|
||||
new_P.m_nvars = P.m_nvars - 1 + nfields;
|
||||
- 1 (we consumed one argument in this step) and + number of introduced constructor fields. */
|
||||
new_P.m_nvars = P.m_nvars - 1 + num_intro_fields;
|
||||
new_P.m_equations = get_equations_for(head(new_goal_cnames), head(ilist), head(rlist),
|
||||
get_local_context(aux_mvar2), equations_by_constructor);
|
||||
result new_R = compile_core(new_P);
|
||||
result_by_constructor.emplace_back(head(new_goal_cnames), nfields, new_P.m_nvars, new_R);
|
||||
result_by_constructor.emplace_back(head(new_goal_cnames), to_bitmask(head(ilist)), new_P.m_nvars, new_R);
|
||||
|
||||
new_goals = tail(new_goals);
|
||||
new_goal_cnames = tail(new_goal_cnames);
|
||||
|
|
@ -639,16 +688,16 @@ struct elim_match_fn {
|
|||
buffer<expr> I_params;
|
||||
levels I_lvls = get_inductive_levels_and_params(ctx, I, I_params);
|
||||
buffer<lemma> new_lemmas;
|
||||
for (std::tuple<name, unsigned, unsigned, result> const & entry : result_by_constructor) {
|
||||
for (std::tuple<name, list<bool>, unsigned, result> const & entry : result_by_constructor) {
|
||||
name const & cname = std::get<0>(entry); /* constructor name */
|
||||
unsigned nfields = std::get<1>(entry);
|
||||
list<bool> mask = std::get<1>(entry); /* bitmask indicating which constructor fields have been introduced by cases-tactic. */
|
||||
unsigned arity = std::get<2>(entry);
|
||||
result const & Rc = std::get<3>(entry);
|
||||
expr c = mk_app(mk_constant(cname, I_lvls), I_params);
|
||||
for (lemma const & L : Rc.m_lemmas) {
|
||||
lemma new_L;
|
||||
new_L.m_lctx = L.m_lctx;
|
||||
new_L.m_vars = L.m_vars;
|
||||
new_L.m_eqn = update_eqn_for_constructor_transition(L.m_eqn, arity, new_R.m_code, c, nfields);
|
||||
new_L.m_eqn = update_eqn_for_constructor_transition(L, mask, arity, new_R.m_code, cname, I_params);
|
||||
new_L.m_proof = L.m_proof;
|
||||
new_lemmas.push_back(new_L);
|
||||
}
|
||||
|
|
@ -690,7 +739,7 @@ struct elim_match_fn {
|
|||
type_context ctx = mk_type_context(get_local_context(P));
|
||||
expr eq = mk_eq(ctx, rhs, rhs);
|
||||
expr H = mk_eq_refl(ctx, rhs);
|
||||
R.m_lemmas = to_list(lemma(list<expr>(), eq, H));
|
||||
R.m_lemmas = to_list(lemma(ctx.lctx(), list<expr>(), eq, H));
|
||||
}
|
||||
return R;
|
||||
}
|
||||
|
|
@ -738,6 +787,17 @@ struct elim_match_fn {
|
|||
return R;
|
||||
}
|
||||
|
||||
void abstract_eqns_vars(list<lemma> const & Ls, buffer<expr_pair> & R) {
|
||||
for (lemma const & L : Ls) {
|
||||
type_context ctx = mk_type_context(L.m_lctx);
|
||||
buffer<expr> vars;
|
||||
to_buffer(L.m_vars, vars);
|
||||
expr e = ctx.mk_pi(vars, L.m_eqn);
|
||||
expr H = ctx.mk_lambda(vars, L.m_proof);
|
||||
R.emplace_back(e, H);
|
||||
}
|
||||
}
|
||||
|
||||
expr operator()(local_context const & lctx, expr const & eqns) {
|
||||
lean_assert(equations_num_fns(eqns) == 1);
|
||||
DEBUG_CODE({
|
||||
|
|
@ -748,6 +808,13 @@ struct elim_match_fn {
|
|||
program P = mk_program(lctx, eqns);
|
||||
result R = compile(P);
|
||||
|
||||
#if 0
|
||||
if (m_lemmas) {
|
||||
buffer<expr_pair> Hs;
|
||||
abstract_eqns_vars(R.m_lemmas, Hs);
|
||||
}
|
||||
#endif
|
||||
|
||||
lean_unreachable();
|
||||
}
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue