feat(kernel/inductive): postprocess recursors and their rules
This commit is contained in:
parent
498cfa84fd
commit
47bc71f4fa
2 changed files with 138 additions and 28 deletions
|
|
@ -272,8 +272,8 @@ public:
|
|||
inductive_val & operator=(inductive_val const & other) { object_ref::operator=(other); return *this; }
|
||||
inductive_val & operator=(inductive_val && other) { object_ref::operator=(other); return *this; }
|
||||
constant_val const & to_constant_val() const { return static_cast<constant_val const &>(cnstr_obj_ref(*this, 0)); }
|
||||
nat const & get_nparams() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 1)); }
|
||||
nat const & get_nindices() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 2)); }
|
||||
unsigned get_nparams() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 1)).get_small_value(); }
|
||||
unsigned get_nindices() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 2)).get_small_value(); }
|
||||
names const & get_all() const { return static_cast<names const &>(cnstr_obj_ref(*this, 3)); }
|
||||
names const & get_cnstrs() const { return static_cast<names const &>(cnstr_obj_ref(*this, 4)); }
|
||||
bool is_rec() const { return cnstr_scalar<unsigned char>(raw(), sizeof(object*)*5) != 0; }
|
||||
|
|
@ -295,7 +295,7 @@ public:
|
|||
constructor_val & operator=(constructor_val && other) { object_ref::operator=(other); return *this; }
|
||||
constant_val const & to_constant_val() const { return static_cast<constant_val const &>(cnstr_obj_ref(*this, 0)); }
|
||||
name const & get_induct() const { return static_cast<name const &>(cnstr_obj_ref(*this, 1)); }
|
||||
nat const & get_nparams() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 2)); }
|
||||
unsigned get_nparams() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 2)).get_small_value(); }
|
||||
bool is_meta() const { return cnstr_scalar<unsigned char>(raw(), sizeof(object*)*3) != 0; }
|
||||
};
|
||||
|
||||
|
|
@ -313,7 +313,7 @@ public:
|
|||
recursor_rule & operator=(recursor_rule const & other) { object_ref::operator=(other); return *this; }
|
||||
recursor_rule & operator=(recursor_rule && other) { object_ref::operator=(other); return *this; }
|
||||
name const & get_constructor() const { return static_cast<name const &>(cnstr_obj_ref(*this, 0)); }
|
||||
nat const & get_nfields() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 1)); }
|
||||
unsigned get_nfields() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 1)).get_small_value(); }
|
||||
expr const & get_rhs() const { return static_cast<expr const &>(cnstr_obj_ref(*this, 2)); }
|
||||
};
|
||||
|
||||
|
|
@ -341,10 +341,10 @@ public:
|
|||
recursor_val & operator=(recursor_val && other) { object_ref::operator=(other); return *this; }
|
||||
constant_val const & to_constant_val() const { return static_cast<constant_val const &>(cnstr_obj_ref(*this, 0)); }
|
||||
names const & get_all() const { return static_cast<names const &>(cnstr_obj_ref(*this, 1)); }
|
||||
nat const & get_nparams() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 2)); }
|
||||
nat const & get_nindices() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 3)); }
|
||||
nat const & get_nmotives() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 4)); }
|
||||
nat const & get_nminors() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 5)); }
|
||||
unsigned get_nparams() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 2)).get_small_value(); }
|
||||
unsigned get_nindices() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 3)).get_small_value(); }
|
||||
unsigned get_nmotives() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 4)).get_small_value(); }
|
||||
unsigned get_nminors() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 5)).get_small_value(); }
|
||||
recursor_rules const & get_rules() const { return static_cast<recursor_rules const &>(cnstr_obj_ref(*this, 6)); }
|
||||
bool is_k() const { return cnstr_scalar<unsigned char>(raw(), sizeof(object*)*7) != 0; }
|
||||
bool is_meta() const { return cnstr_scalar<unsigned char>(raw(), sizeof(object*)*7 + 1) != 0; }
|
||||
|
|
|
|||
|
|
@ -661,27 +661,71 @@ struct elim_nested_inductive_result {
|
|||
}
|
||||
}
|
||||
|
||||
expr restore_nested(expr e) const {
|
||||
/* If `c` is an constructor name associated with an auxiliary inductive type, then return the
|
||||
nested inductive associated with it and the name of its inductive type. Return none. */
|
||||
optional<pair<expr, name>> get_nested_if_aux_constructor(environment const & aux_env, name const & c) const {
|
||||
optional<constant_info> info = aux_env.find(c);
|
||||
if (!info) return optional<pair<expr, name>>();
|
||||
name auxI_name = info->to_constructor_val().get_induct();
|
||||
expr const * nested = m_aux2nested.find(auxI_name);
|
||||
if (!nested) return optional<pair<expr, name>>();
|
||||
return optional<pair<expr, name>>(*nested, auxI_name);
|
||||
}
|
||||
|
||||
name restore_constructor_name(environment const & aux_env, name const & cnstr_name) const {
|
||||
optional<pair<expr, name>> p = get_nested_if_aux_constructor(aux_env, cnstr_name);
|
||||
lean_assert(p);
|
||||
expr const & I = get_app_fn(p->first);
|
||||
lean_assert(is_constant(I));
|
||||
return cnstr_name.replace_prefix(p->second, const_name(I));
|
||||
}
|
||||
|
||||
expr restore_nested(expr e, environment const & aux_env, name_map<name> const & aux_rec_name_map = name_map<name>()) const {
|
||||
lean_assert(is_pi(e) || is_lambda(e));
|
||||
local_ctx lctx;
|
||||
buffer<expr> As;
|
||||
bool pi = is_pi(e);
|
||||
for (unsigned i = 0; i < m_params.size(); i++) {
|
||||
lean_assert(is_pi(e));
|
||||
lean_assert(is_pi(e) || is_lambda(e));
|
||||
As.push_back(lctx.mk_local_decl(binding_name(e), binding_domain(e), binding_info(e)));
|
||||
e = instantiate(binding_body(e), As.back());
|
||||
}
|
||||
e = replace(e, [&](expr const & t, unsigned) {
|
||||
if (!is_app(t)) return none_expr();
|
||||
if (is_constant(t)) {
|
||||
if (name const * rec_name = aux_rec_name_map.find(const_name(t))) {
|
||||
return some_expr(mk_constant(*rec_name, const_levels(t)));
|
||||
}
|
||||
}
|
||||
expr const & fn = get_app_fn(t);
|
||||
if (!is_constant(fn)) return none_expr();
|
||||
expr const * nested = m_aux2nested.find(const_name(fn));
|
||||
if (!nested) return none_expr();
|
||||
buffer<expr> args;
|
||||
get_app_args(t, args);
|
||||
lean_assert(args.size() >= m_params.size());
|
||||
expr new_t = instantiate_rev(abstract(*nested, m_params.size(), m_params.data()), As.size(), As.data());
|
||||
return some_expr(mk_app(new_t, args.size() - m_params.size(), args.data() + m_params.size()));
|
||||
if (is_constant(fn)) {
|
||||
if (expr const * nested = m_aux2nested.find(const_name(fn))) {
|
||||
buffer<expr> args;
|
||||
get_app_args(t, args);
|
||||
lean_assert(args.size() >= m_params.size());
|
||||
expr new_t = instantiate_rev(abstract(*nested, m_params.size(), m_params.data()), As.size(), As.data());
|
||||
return some_expr(mk_app(new_t, args.size() - m_params.size(), args.data() + m_params.size()));
|
||||
}
|
||||
if (optional<pair<expr, name>> r = get_nested_if_aux_constructor(aux_env, const_name(fn))) {
|
||||
expr nested = r->first;
|
||||
name auxI_name = r->second;
|
||||
/* `t` is a constructor-application of an auxiliary inductive type */
|
||||
buffer<expr> args;
|
||||
get_app_args(t, args);
|
||||
lean_assert(args.size() >= m_params.size());
|
||||
expr new_nested = instantiate_rev(abstract(nested, m_params.size(), m_params.data()), As.size(), As.data());
|
||||
buffer<expr> I_args;
|
||||
expr I = get_app_args(new_nested, I_args);
|
||||
lean_assert(I_args.size() == m_params.size());
|
||||
lean_assert(is_constant(I));
|
||||
name new_fn_name = const_name(fn).replace_prefix(auxI_name, const_name(I));
|
||||
expr new_fn = mk_constant(new_fn_name, const_levels(I));
|
||||
expr new_t = mk_app(mk_app(new_fn, I_args), args.size() - I_args.size(), args.data() + I_args.size());
|
||||
return some_expr(new_t);
|
||||
}
|
||||
}
|
||||
return none_expr();
|
||||
});
|
||||
return lctx.mk_pi(As, e);
|
||||
return pi ? lctx.mk_pi(As, e) : lctx.mk_lambda(As, e);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -737,11 +781,11 @@ struct elim_nested_inductive_fn {
|
|||
if (!info || !info->is_inductive()) return optional<inductive_val>();
|
||||
buffer<expr> args;
|
||||
get_app_args(e, args);
|
||||
nat const & nparams = info->to_inductive_val().get_nparams();
|
||||
unsigned nparams = info->to_inductive_val().get_nparams();
|
||||
if (nparams > args.size()) return optional<inductive_val>();
|
||||
bool is_nested = false;
|
||||
bool loose_bvars = false;
|
||||
for (unsigned i = 0; i < nparams.get_small_value(); i++) {
|
||||
for (unsigned i = 0; i < nparams; i++) {
|
||||
if (has_loose_bvars(args[i])) {
|
||||
loose_bvars = true;
|
||||
}
|
||||
|
|
@ -781,7 +825,7 @@ struct elim_nested_inductive_fn {
|
|||
name const & I_name = const_name(fn);
|
||||
levels const & I_lvls = const_levels(fn);
|
||||
lean_assert(I_val->get_nparams() <= args.size());
|
||||
unsigned I_nparams = I_val->get_nparams().get_small_value();
|
||||
unsigned I_nparams = I_val->get_nparams();
|
||||
expr IAs = mk_app(fn, I_nparams, args.data()); /* `I As` */
|
||||
/* Check whether we have already created an auxiliary inductive_type for `I As` */
|
||||
optional<name> auxI_name;
|
||||
|
|
@ -888,6 +932,43 @@ struct elim_nested_inductive_fn {
|
|||
}
|
||||
};
|
||||
|
||||
/* Given the auxiliary environment `aux_env` generated by processing the auxiliary mutual declaration,
|
||||
and the original declaration `d`. This function return a pair `(aux_rec_names, aux_rec_name_map)`
|
||||
where `aux_rec_names` contains the recursor names associated to auxiliary inductive types used to
|
||||
eliminated nested inductive occurrences.
|
||||
The mapping `aux_rec_name_map` contains an entry `(aux_rec_name -> rec_name)` for each
|
||||
element in `aux_rec_names`. It provides the new names for these recursors.
|
||||
|
||||
We compute the new recursor names using the first inductive datatype in the original declaration `d`,
|
||||
and the suffice `.rec_<idx>`. */
|
||||
static pair<names, name_map<name>> mk_aux_rec_name_map(environment const & aux_env, inductive_decl const & d) {
|
||||
unsigned ntypes = length(d.get_types());
|
||||
lean_assert(ntypes > 0);
|
||||
inductive_type const & main_type = head(d.get_types());
|
||||
name const & main_name = main_type.get_name();
|
||||
constant_info main_info = aux_env.get(main_name);
|
||||
names const & all_names = main_info.to_inductive_val().get_all();
|
||||
/* This function is only called if we have created auxiliary inductive types when eliminating
|
||||
the nested inductives. */
|
||||
lean_assert(length(all_names) > ntypes);
|
||||
/* Remark: we use the `main_name` to declarate the auxiliary recursors as: <main_name>.rec_1, <main_name>.rec_2, ...
|
||||
This is a little bit asymmetrical if `d` is a mutual declaration, but it makes sure we have simple names. */
|
||||
buffer<name> old_rec_names;
|
||||
name_map<name> rec_map;
|
||||
unsigned i = 0;
|
||||
unsigned next_idx = 1;
|
||||
for (name const & ind_name : all_names) {
|
||||
if (i >= ntypes) {
|
||||
old_rec_names.push_back(mk_rec_name(ind_name));
|
||||
name new_rec_name = mk_rec_name(main_name).append_after(next_idx);
|
||||
next_idx++;
|
||||
rec_map.insert(old_rec_names.back(), new_rec_name);
|
||||
}
|
||||
i++;
|
||||
}
|
||||
return mk_pair(names(old_rec_names), rec_map);
|
||||
}
|
||||
|
||||
environment environment::add_inductive(declaration const & d) const {
|
||||
elim_nested_inductive_result res = elim_nested_inductive_fn(*this, d)();
|
||||
environment aux_env = add_inductive_fn(*this, inductive_decl(res.m_aux_decl))();
|
||||
|
|
@ -896,9 +977,35 @@ environment environment::add_inductive(declaration const & d) const {
|
|||
return aux_env;
|
||||
} else {
|
||||
/* Restore nested inductives. */
|
||||
environment new_env = *this;
|
||||
inductive_decl ind_d(d);
|
||||
names all_ind_names = get_all_inductive_names(ind_d);
|
||||
names aux_rec_names; name_map<name> aux_rec_name_map;
|
||||
std::tie(aux_rec_names, aux_rec_name_map) = mk_aux_rec_name_map(aux_env, d);
|
||||
environment new_env = *this;
|
||||
auto process_rec = [&](name const & rec_name) {
|
||||
name new_rec_name = rec_name;
|
||||
if (name const * new_name = aux_rec_name_map.find(rec_name))
|
||||
new_rec_name = *new_name;
|
||||
constant_info rec_info = aux_env.get(rec_name);
|
||||
expr new_rec_type = res.restore_nested(rec_info.get_type(), aux_env, aux_rec_name_map);
|
||||
recursor_val rec_val = rec_info.to_recursor_val();
|
||||
buffer<recursor_rule> new_rules;
|
||||
for (recursor_rule const & rule : rec_val.get_rules()) {
|
||||
expr new_rhs = res.restore_nested(rule.get_rhs(), aux_env, aux_rec_name_map);
|
||||
name cnstr_name = rule.get_constructor();
|
||||
name new_cnstr_name = cnstr_name;
|
||||
if (new_rec_name != rec_name) {
|
||||
/* We need to fix the constructor name */
|
||||
new_cnstr_name = res.restore_constructor_name(aux_env, cnstr_name);
|
||||
}
|
||||
new_rules.push_back(recursor_rule(new_cnstr_name, rule.get_nfields(), new_rhs));
|
||||
}
|
||||
new_env.check_name(new_rec_name);
|
||||
new_env.add_core(constant_info(recursor_val(new_rec_name, rec_info.get_univ_params(), new_rec_type,
|
||||
all_ind_names, rec_val.get_nparams(), rec_val.get_nindices(), rec_val.get_nmotives(),
|
||||
rec_val.get_nminors(), recursor_rules(new_rules),
|
||||
rec_val.is_k(), rec_val.is_meta())));
|
||||
};
|
||||
for (inductive_type const & ind_type : ind_d.get_types()) {
|
||||
constant_info ind_info = aux_env.get(ind_type.get_name());
|
||||
inductive_val ind_val = ind_info.to_inductive_val();
|
||||
|
|
@ -906,17 +1013,20 @@ environment environment::add_inductive(declaration const & d) const {
|
|||
|
||||
Remark: if we decide to store the recursor names, we will also need to fix it. */
|
||||
new_env.add_core(constant_info(inductive_val(ind_info.get_name(), ind_info.get_univ_params(), ind_info.get_type(),
|
||||
ind_val.get_nparams().get_small_value(), ind_val.get_nindices().get_small_value(),
|
||||
ind_val.get_nparams(), ind_val.get_nindices(),
|
||||
all_ind_names, ind_val.get_cnstrs(), ind_val.is_rec(), ind_val.is_meta())));
|
||||
for (name const & cnstr_name : ind_val.get_cnstrs()) {
|
||||
constant_info cnstr_info = aux_env.get(cnstr_name);
|
||||
constructor_val cnstr_val = cnstr_info.to_constructor_val();
|
||||
expr new_type = res.restore_nested(cnstr_info.get_type());
|
||||
expr new_type = res.restore_nested(cnstr_info.get_type(), aux_env);
|
||||
new_env.add_core(constant_info(constructor_val(cnstr_info.get_name(), cnstr_info.get_univ_params(), new_type,
|
||||
cnstr_val.get_induct(), cnstr_val.get_nparams().get_small_value(),
|
||||
cnstr_val.get_induct(), cnstr_val.get_nparams(),
|
||||
cnstr_val.is_meta())));
|
||||
}
|
||||
/* TODO(leo): restore recursor and reduction rules. */
|
||||
process_rec(mk_rec_name(ind_type.get_name()));
|
||||
}
|
||||
for (name const & aux_rec : aux_rec_names) {
|
||||
process_rec(aux_rec);
|
||||
}
|
||||
return new_env;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue