diff --git a/src/kernel/declaration.h b/src/kernel/declaration.h index cb1dee9df1..00c7a2241e 100644 --- a/src/kernel/declaration.h +++ b/src/kernel/declaration.h @@ -272,8 +272,8 @@ public: inductive_val & operator=(inductive_val const & other) { object_ref::operator=(other); return *this; } inductive_val & operator=(inductive_val && other) { object_ref::operator=(other); return *this; } constant_val const & to_constant_val() const { return static_cast(cnstr_obj_ref(*this, 0)); } - nat const & get_nparams() const { return static_cast(cnstr_obj_ref(*this, 1)); } - nat const & get_nindices() const { return static_cast(cnstr_obj_ref(*this, 2)); } + unsigned get_nparams() const { return static_cast(cnstr_obj_ref(*this, 1)).get_small_value(); } + unsigned get_nindices() const { return static_cast(cnstr_obj_ref(*this, 2)).get_small_value(); } names const & get_all() const { return static_cast(cnstr_obj_ref(*this, 3)); } names const & get_cnstrs() const { return static_cast(cnstr_obj_ref(*this, 4)); } bool is_rec() const { return cnstr_scalar(raw(), sizeof(object*)*5) != 0; } @@ -295,7 +295,7 @@ public: constructor_val & operator=(constructor_val && other) { object_ref::operator=(other); return *this; } constant_val const & to_constant_val() const { return static_cast(cnstr_obj_ref(*this, 0)); } name const & get_induct() const { return static_cast(cnstr_obj_ref(*this, 1)); } - nat const & get_nparams() const { return static_cast(cnstr_obj_ref(*this, 2)); } + unsigned get_nparams() const { return static_cast(cnstr_obj_ref(*this, 2)).get_small_value(); } bool is_meta() const { return cnstr_scalar(raw(), sizeof(object*)*3) != 0; } }; @@ -313,7 +313,7 @@ public: recursor_rule & operator=(recursor_rule const & other) { object_ref::operator=(other); return *this; } recursor_rule & operator=(recursor_rule && other) { object_ref::operator=(other); return *this; } name const & get_constructor() const { return static_cast(cnstr_obj_ref(*this, 0)); } - nat const & get_nfields() const { return static_cast(cnstr_obj_ref(*this, 1)); } + unsigned get_nfields() const { return static_cast(cnstr_obj_ref(*this, 1)).get_small_value(); } expr const & get_rhs() const { return static_cast(cnstr_obj_ref(*this, 2)); } }; @@ -341,10 +341,10 @@ public: recursor_val & operator=(recursor_val && other) { object_ref::operator=(other); return *this; } constant_val const & to_constant_val() const { return static_cast(cnstr_obj_ref(*this, 0)); } names const & get_all() const { return static_cast(cnstr_obj_ref(*this, 1)); } - nat const & get_nparams() const { return static_cast(cnstr_obj_ref(*this, 2)); } - nat const & get_nindices() const { return static_cast(cnstr_obj_ref(*this, 3)); } - nat const & get_nmotives() const { return static_cast(cnstr_obj_ref(*this, 4)); } - nat const & get_nminors() const { return static_cast(cnstr_obj_ref(*this, 5)); } + unsigned get_nparams() const { return static_cast(cnstr_obj_ref(*this, 2)).get_small_value(); } + unsigned get_nindices() const { return static_cast(cnstr_obj_ref(*this, 3)).get_small_value(); } + unsigned get_nmotives() const { return static_cast(cnstr_obj_ref(*this, 4)).get_small_value(); } + unsigned get_nminors() const { return static_cast(cnstr_obj_ref(*this, 5)).get_small_value(); } recursor_rules const & get_rules() const { return static_cast(cnstr_obj_ref(*this, 6)); } bool is_k() const { return cnstr_scalar(raw(), sizeof(object*)*7) != 0; } bool is_meta() const { return cnstr_scalar(raw(), sizeof(object*)*7 + 1) != 0; } diff --git a/src/kernel/inductive.cpp b/src/kernel/inductive.cpp index 30755d8dd9..d5be4763b8 100644 --- a/src/kernel/inductive.cpp +++ b/src/kernel/inductive.cpp @@ -661,27 +661,71 @@ struct elim_nested_inductive_result { } } - expr restore_nested(expr e) const { + /* If `c` is an constructor name associated with an auxiliary inductive type, then return the + nested inductive associated with it and the name of its inductive type. Return none. */ + optional> get_nested_if_aux_constructor(environment const & aux_env, name const & c) const { + optional info = aux_env.find(c); + if (!info) return optional>(); + name auxI_name = info->to_constructor_val().get_induct(); + expr const * nested = m_aux2nested.find(auxI_name); + if (!nested) return optional>(); + return optional>(*nested, auxI_name); + } + + name restore_constructor_name(environment const & aux_env, name const & cnstr_name) const { + optional> p = get_nested_if_aux_constructor(aux_env, cnstr_name); + lean_assert(p); + expr const & I = get_app_fn(p->first); + lean_assert(is_constant(I)); + return cnstr_name.replace_prefix(p->second, const_name(I)); + } + + expr restore_nested(expr e, environment const & aux_env, name_map const & aux_rec_name_map = name_map()) const { + lean_assert(is_pi(e) || is_lambda(e)); local_ctx lctx; buffer As; + bool pi = is_pi(e); for (unsigned i = 0; i < m_params.size(); i++) { - lean_assert(is_pi(e)); + lean_assert(is_pi(e) || is_lambda(e)); As.push_back(lctx.mk_local_decl(binding_name(e), binding_domain(e), binding_info(e))); e = instantiate(binding_body(e), As.back()); } e = replace(e, [&](expr const & t, unsigned) { - if (!is_app(t)) return none_expr(); + if (is_constant(t)) { + if (name const * rec_name = aux_rec_name_map.find(const_name(t))) { + return some_expr(mk_constant(*rec_name, const_levels(t))); + } + } expr const & fn = get_app_fn(t); - if (!is_constant(fn)) return none_expr(); - expr const * nested = m_aux2nested.find(const_name(fn)); - if (!nested) return none_expr(); - buffer args; - get_app_args(t, args); - lean_assert(args.size() >= m_params.size()); - expr new_t = instantiate_rev(abstract(*nested, m_params.size(), m_params.data()), As.size(), As.data()); - return some_expr(mk_app(new_t, args.size() - m_params.size(), args.data() + m_params.size())); + if (is_constant(fn)) { + if (expr const * nested = m_aux2nested.find(const_name(fn))) { + buffer args; + get_app_args(t, args); + lean_assert(args.size() >= m_params.size()); + expr new_t = instantiate_rev(abstract(*nested, m_params.size(), m_params.data()), As.size(), As.data()); + return some_expr(mk_app(new_t, args.size() - m_params.size(), args.data() + m_params.size())); + } + if (optional> r = get_nested_if_aux_constructor(aux_env, const_name(fn))) { + expr nested = r->first; + name auxI_name = r->second; + /* `t` is a constructor-application of an auxiliary inductive type */ + buffer args; + get_app_args(t, args); + lean_assert(args.size() >= m_params.size()); + expr new_nested = instantiate_rev(abstract(nested, m_params.size(), m_params.data()), As.size(), As.data()); + buffer I_args; + expr I = get_app_args(new_nested, I_args); + lean_assert(I_args.size() == m_params.size()); + lean_assert(is_constant(I)); + name new_fn_name = const_name(fn).replace_prefix(auxI_name, const_name(I)); + expr new_fn = mk_constant(new_fn_name, const_levels(I)); + expr new_t = mk_app(mk_app(new_fn, I_args), args.size() - I_args.size(), args.data() + I_args.size()); + return some_expr(new_t); + } + } + return none_expr(); }); - return lctx.mk_pi(As, e); + return pi ? lctx.mk_pi(As, e) : lctx.mk_lambda(As, e); } }; @@ -737,11 +781,11 @@ struct elim_nested_inductive_fn { if (!info || !info->is_inductive()) return optional(); buffer args; get_app_args(e, args); - nat const & nparams = info->to_inductive_val().get_nparams(); + unsigned nparams = info->to_inductive_val().get_nparams(); if (nparams > args.size()) return optional(); bool is_nested = false; bool loose_bvars = false; - for (unsigned i = 0; i < nparams.get_small_value(); i++) { + for (unsigned i = 0; i < nparams; i++) { if (has_loose_bvars(args[i])) { loose_bvars = true; } @@ -781,7 +825,7 @@ struct elim_nested_inductive_fn { name const & I_name = const_name(fn); levels const & I_lvls = const_levels(fn); lean_assert(I_val->get_nparams() <= args.size()); - unsigned I_nparams = I_val->get_nparams().get_small_value(); + unsigned I_nparams = I_val->get_nparams(); expr IAs = mk_app(fn, I_nparams, args.data()); /* `I As` */ /* Check whether we have already created an auxiliary inductive_type for `I As` */ optional auxI_name; @@ -888,6 +932,43 @@ struct elim_nested_inductive_fn { } }; +/* Given the auxiliary environment `aux_env` generated by processing the auxiliary mutual declaration, + and the original declaration `d`. This function return a pair `(aux_rec_names, aux_rec_name_map)` + where `aux_rec_names` contains the recursor names associated to auxiliary inductive types used to + eliminated nested inductive occurrences. + The mapping `aux_rec_name_map` contains an entry `(aux_rec_name -> rec_name)` for each + element in `aux_rec_names`. It provides the new names for these recursors. + + We compute the new recursor names using the first inductive datatype in the original declaration `d`, + and the suffice `.rec_`. */ +static pair> mk_aux_rec_name_map(environment const & aux_env, inductive_decl const & d) { + unsigned ntypes = length(d.get_types()); + lean_assert(ntypes > 0); + inductive_type const & main_type = head(d.get_types()); + name const & main_name = main_type.get_name(); + constant_info main_info = aux_env.get(main_name); + names const & all_names = main_info.to_inductive_val().get_all(); + /* This function is only called if we have created auxiliary inductive types when eliminating + the nested inductives. */ + lean_assert(length(all_names) > ntypes); + /* Remark: we use the `main_name` to declarate the auxiliary recursors as: .rec_1, .rec_2, ... + This is a little bit asymmetrical if `d` is a mutual declaration, but it makes sure we have simple names. */ + buffer old_rec_names; + name_map rec_map; + unsigned i = 0; + unsigned next_idx = 1; + for (name const & ind_name : all_names) { + if (i >= ntypes) { + old_rec_names.push_back(mk_rec_name(ind_name)); + name new_rec_name = mk_rec_name(main_name).append_after(next_idx); + next_idx++; + rec_map.insert(old_rec_names.back(), new_rec_name); + } + i++; + } + return mk_pair(names(old_rec_names), rec_map); +} + environment environment::add_inductive(declaration const & d) const { elim_nested_inductive_result res = elim_nested_inductive_fn(*this, d)(); environment aux_env = add_inductive_fn(*this, inductive_decl(res.m_aux_decl))(); @@ -896,9 +977,35 @@ environment environment::add_inductive(declaration const & d) const { return aux_env; } else { /* Restore nested inductives. */ - environment new_env = *this; inductive_decl ind_d(d); names all_ind_names = get_all_inductive_names(ind_d); + names aux_rec_names; name_map aux_rec_name_map; + std::tie(aux_rec_names, aux_rec_name_map) = mk_aux_rec_name_map(aux_env, d); + environment new_env = *this; + auto process_rec = [&](name const & rec_name) { + name new_rec_name = rec_name; + if (name const * new_name = aux_rec_name_map.find(rec_name)) + new_rec_name = *new_name; + constant_info rec_info = aux_env.get(rec_name); + expr new_rec_type = res.restore_nested(rec_info.get_type(), aux_env, aux_rec_name_map); + recursor_val rec_val = rec_info.to_recursor_val(); + buffer new_rules; + for (recursor_rule const & rule : rec_val.get_rules()) { + expr new_rhs = res.restore_nested(rule.get_rhs(), aux_env, aux_rec_name_map); + name cnstr_name = rule.get_constructor(); + name new_cnstr_name = cnstr_name; + if (new_rec_name != rec_name) { + /* We need to fix the constructor name */ + new_cnstr_name = res.restore_constructor_name(aux_env, cnstr_name); + } + new_rules.push_back(recursor_rule(new_cnstr_name, rule.get_nfields(), new_rhs)); + } + new_env.check_name(new_rec_name); + new_env.add_core(constant_info(recursor_val(new_rec_name, rec_info.get_univ_params(), new_rec_type, + all_ind_names, rec_val.get_nparams(), rec_val.get_nindices(), rec_val.get_nmotives(), + rec_val.get_nminors(), recursor_rules(new_rules), + rec_val.is_k(), rec_val.is_meta()))); + }; for (inductive_type const & ind_type : ind_d.get_types()) { constant_info ind_info = aux_env.get(ind_type.get_name()); inductive_val ind_val = ind_info.to_inductive_val(); @@ -906,17 +1013,20 @@ environment environment::add_inductive(declaration const & d) const { Remark: if we decide to store the recursor names, we will also need to fix it. */ new_env.add_core(constant_info(inductive_val(ind_info.get_name(), ind_info.get_univ_params(), ind_info.get_type(), - ind_val.get_nparams().get_small_value(), ind_val.get_nindices().get_small_value(), + ind_val.get_nparams(), ind_val.get_nindices(), all_ind_names, ind_val.get_cnstrs(), ind_val.is_rec(), ind_val.is_meta()))); for (name const & cnstr_name : ind_val.get_cnstrs()) { constant_info cnstr_info = aux_env.get(cnstr_name); constructor_val cnstr_val = cnstr_info.to_constructor_val(); - expr new_type = res.restore_nested(cnstr_info.get_type()); + expr new_type = res.restore_nested(cnstr_info.get_type(), aux_env); new_env.add_core(constant_info(constructor_val(cnstr_info.get_name(), cnstr_info.get_univ_params(), new_type, - cnstr_val.get_induct(), cnstr_val.get_nparams().get_small_value(), + cnstr_val.get_induct(), cnstr_val.get_nparams(), cnstr_val.is_meta()))); } - /* TODO(leo): restore recursor and reduction rules. */ + process_rec(mk_rec_name(ind_type.get_name())); + } + for (name const & aux_rec : aux_rec_names) { + process_rec(aux_rec); } return new_env; }