From 517923d362e970ff8f671405f741e91ce2bf072b Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 31 Aug 2018 17:47:22 -0700 Subject: [PATCH] feat(kernel/inductive): generate recursors in the new inductive datatype module --- library/init/lean/declaration.lean | 2 +- src/frontends/lean/print_cmd.cpp | 2 + src/kernel/declaration.cpp | 18 ++- src/kernel/declaration.h | 13 +- src/kernel/inductive.cpp | 199 ++++++++++++++++++++++++++--- src/util/object_ref.h | 23 ++++ 6 files changed, 235 insertions(+), 22 deletions(-) diff --git a/library/init/lean/declaration.lean b/library/init/lean/declaration.lean index a8f809b9f8..60e80eca5f 100644 --- a/library/init/lean/declaration.lean +++ b/library/init/lean/declaration.lean @@ -99,8 +99,8 @@ structure recursor_val extends constant_val := (nindices : nat) -- Number of indices (nmotives : nat) -- Number of motives (nminor : nat) -- Number of minor premises -(k : bool) -- It supports K-like reduction (rules : list recursor_rule) -- A reduction for each constructor +(k : bool) -- It supports K-like reduction (is_meta : bool) inductive quot_kind diff --git a/src/frontends/lean/print_cmd.cpp b/src/frontends/lean/print_cmd.cpp index e4898bb9e1..32ba5b8400 100644 --- a/src/frontends/lean/print_cmd.cpp +++ b/src/frontends/lean/print_cmd.cpp @@ -383,6 +383,8 @@ void print_id_info(parser & p, message_builder & out, name const & id, bool show out << n << "\n"; } else if (d.is_constructor()) { print_constant(p, out, "(new) constructor", d); + } else if (d.is_recursor()) { + print_constant(p, out, "(new) recursor", d); } // print_patterns(p, c); } diff --git a/src/kernel/declaration.cpp b/src/kernel/declaration.cpp index 99d66e3eda..ca088b4904 100644 --- a/src/kernel/declaration.cpp +++ b/src/kernel/declaration.cpp @@ -84,12 +84,22 @@ constructor_val::constructor_val(name const & n, level_param_names const & lpara cnstr_set_scalar(raw(), sizeof(object*)*3, static_cast(is_meta)); } +recursor_val::recursor_val(name const & n, level_param_names const & lparams, expr const & type, + names const & all, unsigned nparams, unsigned nindices, unsigned nmotives, + unsigned nminors, recursor_rules const & rules, bool k, bool is_meta): + object_ref(mk_cnstr(0, constant_val(n, lparams, type), all, nat(nparams), nat(nindices), nat(nmotives), + nat(nminors), rules, 2)) { + cnstr_set_scalar(raw(), sizeof(object*)*7, static_cast(k)); + cnstr_set_scalar(raw(), sizeof(object*)*7 + 1, static_cast(is_meta)); +} + + bool declaration::is_meta() const { switch (kind()) { case declaration_kind::Definition: return to_definition_val().is_meta(); case declaration_kind::Axiom: return to_axiom_val().is_meta(); case declaration_kind::Theorem: return false; - case declaration_kind::Inductive: lean_unreachable(); // TODO(Leo): + case declaration_kind::Inductive: return inductive_decl(*this).is_meta(); case declaration_kind::Quot: return false; case declaration_kind::MutualDefinition: return true; } @@ -220,6 +230,10 @@ constant_info::constant_info(constructor_val const & v): object_ref(mk_cnstr(static_cast(constant_info_kind::Constructor), v)) { } +constant_info::constant_info(recursor_val const & v): + object_ref(mk_cnstr(static_cast(constant_info_kind::Recursor), v)) { +} + static reducibility_hints * g_opaque = nullptr; reducibility_hints const & constant_info::get_hints() const { @@ -237,7 +251,7 @@ bool constant_info::is_meta() const { case constant_info_kind::Quot: return false; case constant_info_kind::Inductive: return to_inductive_val().is_meta(); case constant_info_kind::Constructor: return to_constructor_val().is_meta(); - case constant_info_kind::Recursor: return false; // TODO(Leo): to_recursor_val().is_meta(); + case constant_info_kind::Recursor: return to_recursor_val().is_meta(); } lean_unreachable(); } diff --git a/src/kernel/declaration.h b/src/kernel/declaration.h index c8360090bf..cb1dee9df1 100644 --- a/src/kernel/declaration.h +++ b/src/kernel/declaration.h @@ -326,12 +326,15 @@ structure recursor_val extends constant_val := (nindices : nat) -- Number of indices (nmotives : nat) -- Number of motives (nminors : nat) -- Number of minor premises -(k : bool) -- It supports K-like reduction (rules : list recursor_rule) -- A reduction for each constructor +(k : bool) -- It supports K-like reduction (is_meta : bool) */ class recursor_val : public object_ref { public: + recursor_val(name const & n, level_param_names const & lparams, expr const & type, + names const & all, unsigned nparams, unsigned nindices, unsigned nmotives, + unsigned nminors, recursor_rules const & rules, bool k, bool is_meta); recursor_val(recursor_val const & other):object_ref(other) {} recursor_val(recursor_val && other):object_ref(other) {} recursor_val & operator=(recursor_val const & other) { object_ref::operator=(other); return *this; } @@ -343,8 +346,8 @@ public: nat const & get_nmotives() const { return static_cast(cnstr_obj_ref(*this, 4)); } nat const & get_nminors() const { return static_cast(cnstr_obj_ref(*this, 5)); } recursor_rules const & get_rules() const { return static_cast(cnstr_obj_ref(*this, 6)); } - bool is_k() const; - bool is_meta() const; + bool is_k() const { return cnstr_scalar(raw(), sizeof(object*)*7) != 0; } + bool is_meta() const { return cnstr_scalar(raw(), sizeof(object*)*7 + 1) != 0; } }; enum class quot_kind { Type, Mk, Lift, Ind }; @@ -396,6 +399,7 @@ public: constant_info(quot_val const & v); constant_info(inductive_val const & v); constant_info(constructor_val const & v); + constant_info(recursor_val const & v); constant_info(constant_info const & other):object_ref(other) {} constant_info(constant_info && other):object_ref(other) {} @@ -413,6 +417,7 @@ public: bool is_theorem() const { return kind() == constant_info_kind::Theorem; } bool is_inductive() const { return kind() == constant_info_kind::Inductive; } bool is_constructor() const { return kind() == constant_info_kind::Constructor; } + bool is_recursor() const { return kind() == constant_info_kind::Recursor; } name const & get_name() const { return to_constant_val().get_name(); } level_param_names const & get_univ_params() const { return to_constant_val().get_lparams(); } @@ -427,7 +432,7 @@ public: theorem_val const & to_theorem_val() const { lean_assert(is_theorem()); return static_cast(to_val()); } inductive_val const & to_inductive_val() const { lean_assert(is_inductive()); return static_cast(to_val()); } constructor_val const & to_constructor_val() const { lean_assert(is_constructor()); return static_cast(to_val()); } - // recursor_val const & to_recursor_val() const { lean_assert(is_recursor()); return static_cast(to_val()); } + recursor_val const & to_recursor_val() const { lean_assert(is_recursor()); return static_cast(to_val()); } }; inline optional none_constant_info() { return optional(); } diff --git a/src/kernel/inductive.cpp b/src/kernel/inductive.cpp index 68ff23cb72..b1cd8bfbd3 100644 --- a/src/kernel/inductive.cpp +++ b/src/kernel/inductive.cpp @@ -49,11 +49,8 @@ class add_inductive_fn { bool m_K_target; struct rec_info { - name m_name; - local_ctx m_lctx; - buffer m_Cs; /* free variables for all motives */ expr m_C; /* free variable for "main" motive */ - buffer m_minor; /* minor premises */ + buffer m_minors; /* minor premises */ buffer m_indices; expr m_major; /* major premise */ }; @@ -78,6 +75,21 @@ public: return m_lctx.get_local_decl(m_params[i]).get_type(); } + expr mk_local_decl(local_ctx & lctx, name const & n, expr const & t, binder_info const & bi = binder_info()) { + return lctx.mk_local_decl(m_ngen, n, t, bi); + } + + expr mk_local_decl_for(local_ctx & lctx, expr const & t) { + lean_assert(is_pi(t)); + return lctx.mk_local_decl(m_ngen, binding_name(t), binding_domain(t), binding_info(t)); + } + + expr whnf(local_ctx const & lctx, expr const & t) { return tc(lctx).whnf(t); } + + expr infer_type(local_ctx const & lctx, expr const & t) { return tc(lctx).infer(t); } + + bool is_def_eq(local_ctx const & lctx, expr const & t1, expr const & t2) { return tc(lctx).is_def_eq(t1, t2); } + /** \brief Check whether the type of each datatype is well typed, and do not contain free variables or meta variables, all inductive datatypes have the same parameters, the number of parameters match the argument m_nparams, @@ -105,11 +117,11 @@ public: while (is_pi(type)) { if (i < m_nparams) { if (first) { - expr param = m_lctx.mk_local_decl(m_ngen, binding_name(type), binding_domain(type), binding_info(type)); + expr param = mk_local_decl_for(m_lctx, type); m_params.push_back(param); type = instantiate(binding_body(type), param); } else { - if (!tc(m_lctx).is_def_eq(binding_domain(type), get_param_type(i))) + if (!is_def_eq(m_lctx, binding_domain(type), get_param_type(i))) throw kernel_exception(m_env, "parameters of all inductive datatypes must match"); type = instantiate(binding_body(type), m_params[i]); } @@ -230,24 +242,24 @@ public: /** \brief Return `some(d_idx)` iff `t` is a recursive argument, `d_idx` is the index of the recursive inductive datatype. Otherwise, return `none`. */ optional is_rec_argument(local_ctx lctx, expr t) { - t = tc(lctx).whnf(t); + t = whnf(lctx, t); while (is_pi(t)) { - expr local = lctx.mk_local_decl(m_ngen, binding_name(t), binding_domain(t), binding_info(t)); - t = tc(lctx).whnf(instantiate(binding_body(t), local)); + expr local = mk_local_decl_for(lctx, t); + t = whnf(lctx, instantiate(binding_body(t), local)); } return is_valid_ind_app(t); } /** \brief Check if \c t contains only positive occurrences of the inductive datatypes being declared. */ void check_positivity(local_ctx lctx, expr t, name const & cnstr_name, int arg_idx) { - t = tc(lctx).whnf(t); + t = whnf(lctx, t); if (!has_ind_occ(t)) { // nonrecursive argument } else if (is_pi(t)) { if (has_ind_occ(binding_domain(t))) throw kernel_exception(m_env, sstream() << "arg #" << (arg_idx + 1) << " of '" << cnstr_name << "' " "has a non positive occurrence of the datatypes being declared"); - expr local = lctx.mk_local_decl(m_ngen, binding_name(t), binding_domain(t), binding_info(t)); + expr local = mk_local_decl_for(lctx, t); check_positivity(lctx, instantiate(binding_body(t), local), cnstr_name, arg_idx); } else if (is_valid_ind_app(t)) { // recursive argument @@ -272,7 +284,7 @@ public: local_ctx lctx = m_lctx; while (is_pi(t)) { if (i < m_nparams) { - if (!tc(lctx).is_def_eq(binding_domain(t), get_param_type(i))) + if (!is_def_eq(lctx, binding_domain(t), get_param_type(i))) throw kernel_exception(m_env, sstream() << "arg #" << (i + 1) << " of '" << n << "' " << "does not match inductive datatypes parameters'"); t = instantiate(binding_body(t), m_params[i]); @@ -287,7 +299,7 @@ public: } if (!m_is_meta) check_positivity(lctx, binding_domain(t), n, i); - expr local = lctx.mk_local_decl(m_ngen, binding_name(t), binding_domain(t), binding_info(t)); + expr local = mk_local_decl_for(lctx, t); t = instantiate(binding_body(t), local); } i++; @@ -346,7 +358,7 @@ public: buffer to_check; /* Arguments that we must check if occur in the result type */ local_ctx lctx; while (is_pi(type)) { - expr fvar = lctx.mk_local_decl(m_ngen, binding_name(type), binding_domain(type), binding_info(type)); + expr fvar = mk_local_decl_for(lctx, type); if (i >= m_nparams) { expr s = tc(lctx).ensure_type(binding_domain(type)); if (!is_zero(sort_level(s))) { @@ -381,7 +393,6 @@ public: } m_elim_level = mk_univ_param(u); } - // std::cout << ">> elim_level: " << m_elim_level << "\n"; } void init_K_target() { @@ -408,6 +419,162 @@ public: } } + /** \brief Given `t` of the form `I As is` where `I` is one of the inductive datatypes being defined, + As are the global parameters, and is the actual indices provided to it. + Return the index of `I`, and store is in the argument `indices`. */ + unsigned get_I_indices(expr const & t, buffer & indices) { + optional r = is_valid_ind_app(t); + lean_assert(r); + buffer all_args; + get_app_args(t, all_args); + for (unsigned i = m_nparams; i < all_args.size(); i++) + indices.push_back(all_args[i]); + return *r; + } + + /** \brief Populate m_rec_infos. */ + void mk_rec_infos() { + unsigned d_idx = 0; + /* First, populate the fields, m_C, m_indices, m_lctx, m_major */ + for (inductive_type const & ind_type : m_ind_types) { + rec_info info; + expr t = ind_type.get_type(); + unsigned i = 0; + while (is_pi(t)) { + if (i < m_nparams) { + t = instantiate(binding_body(t), m_params[i]); + } else { + expr idx = mk_local_decl_for(m_lctx, t); + info.m_indices.push_back(idx); + t = instantiate(binding_body(t), idx); + } + i++; + } + info.m_major = mk_local_decl(m_lctx, "t", + mk_app(mk_app(m_ind_cnsts[d_idx], m_params), info.m_indices)); + expr C_ty = mk_sort(m_elim_level); + C_ty = m_lctx.mk_pi(info.m_major, C_ty); + C_ty = m_lctx.mk_pi(info.m_indices, C_ty); + name C_name("C"); + if (m_ind_types.size() > 1) + C_name = name(C_name).append_after(d_idx+1); + info.m_C = mk_local_decl(m_lctx, C_name, C_ty); + m_rec_infos.push_back(info); + d_idx++; + } + /* First, populate the field m_minors */ + unsigned minor_idx = 1; + d_idx = 0; + for (inductive_type const & ind_type : m_ind_types) { + for (constructor const & cnstr : ind_type.get_cnstrs()) { + buffer b_u; // nonrec and rec args; + buffer u; // rec args + buffer v; // inductive args + expr t = constructor_type(cnstr); + unsigned i = 0; + while (is_pi(t)) { + if (i < m_nparams) { + t = instantiate(binding_body(t), m_params[i]); + } else { + expr l = mk_local_decl_for(m_lctx, t); + b_u.push_back(l); + if (is_rec_argument(m_lctx, binding_domain(t))) + u.push_back(l); + t = instantiate(binding_body(t), l); + } + i++; + } + buffer it_indices; + unsigned it_idx = get_I_indices(t, it_indices); + expr C_app = mk_app(m_rec_infos[it_idx].m_C, it_indices); + expr intro_app = mk_app(mk_app(mk_constant(constructor_name(cnstr), m_levels), m_params), b_u); + C_app = mk_app(C_app, intro_app); + /* populate v using u */ + for (unsigned i = 0; i < u.size(); i++) { + expr u_i = u[i]; + expr u_i_ty = whnf(m_lctx, infer_type(m_lctx, u_i)); + buffer xs; + while (is_pi(u_i_ty)) { + expr x = mk_local_decl_for(m_lctx, u_i_ty); + xs.push_back(x); + u_i_ty = whnf(m_lctx, instantiate(binding_body(u_i_ty), x)); + } + buffer it_indices; + unsigned it_idx = get_I_indices(u_i_ty, it_indices); + expr C_app = mk_app(m_rec_infos[it_idx].m_C, it_indices); + expr u_app = mk_app(u_i, xs); + C_app = mk_app(C_app, u_app); + expr v_i_ty = m_lctx.mk_pi(xs, C_app); + expr v_i = mk_local_decl(m_lctx, name("v").append_after(i), v_i_ty, binder_info()); + v.push_back(v_i); + } + expr minor_ty = m_lctx.mk_pi(b_u, m_lctx.mk_pi(v, C_app)); + expr minor = mk_local_decl(m_lctx, name("m").append_after(minor_idx), minor_ty); + m_rec_infos[d_idx].m_minors.push_back(minor); + minor_idx++; + } + d_idx++; + } + } + + /** \brief Return the levels for the recursor. */ + levels get_rec_level_params() { + if (is_param(m_elim_level)) + return levels(m_elim_level, m_levels); + else + return m_levels; + } + + /** \brief Return the level parameter names for the recursor. */ + names get_rec_level_param_names() { + if (is_param(m_elim_level)) + return level_param_names(param_id(m_elim_level), m_lparams); + else + return m_lparams; + } + + /** \brief Declare recursors. */ + void declare_recursors() { + names all = get_all_names(); + for (unsigned d_idx = 0; d_idx < m_ind_types.size(); d_idx++) { + rec_info const & info = m_rec_infos[d_idx]; + expr C_app = mk_app(mk_app(info.m_C, info.m_indices), info.m_major); + expr rec_ty = m_lctx.mk_pi(info.m_major, C_app); + rec_ty = m_lctx.mk_pi(info.m_indices, rec_ty); + /* Add minor premises */ + unsigned nminors = 0; + unsigned i = m_ind_types.size(); + while (i > 0) { + --i; + unsigned j = m_rec_infos[i].m_minors.size(); + while (j > 0) { + --j; + rec_ty = m_lctx.mk_pi(m_rec_infos[i].m_minors[j], rec_ty); + nminors++; + } + } + /* Add type formers (aka motives) */ + unsigned nmotives = 0; + i = m_ind_types.size(); + while (i > 0) { + --i; + rec_ty = m_lctx.mk_pi(m_rec_infos[i].m_C, rec_ty); + nmotives++; + } + rec_ty = m_lctx.mk_pi(m_params, rec_ty); + rec_ty = infer_implicit(rec_ty, true /* strict */); + /* + TODO(Leo): gen reduction rule + */ + recursor_rules rules; + name rec_name = mk_rec_name(m_ind_types[d_idx].get_name()); + names rec_lparams = get_rec_level_param_names(); + m_env.add_core(constant_info(recursor_val(rec_name, rec_lparams, rec_ty, all, + m_nparams, m_nindices[d_idx], nmotives, nminors, + rules, m_K_target, m_is_meta))); + } + } + environment operator()() { m_env.check_duplicated_univ_params(m_lparams); check_inductive_types(); @@ -416,6 +583,8 @@ public: declare_constructors(); init_elim_level(); init_K_target(); + mk_rec_infos(); + declare_recursors(); return m_env; } }; diff --git a/src/util/object_ref.h b/src/util/object_ref.h index dec8ddbd26..88c888c4b2 100644 --- a/src/util/object_ref.h +++ b/src/util/object_ref.h @@ -95,6 +95,29 @@ inline object_ref mk_cnstr(unsigned tag, object_ref const & o1, object_ref const return object_ref(r); } +inline object_ref mk_cnstr(unsigned tag, object_ref const & o1, object_ref const & o2, object_ref const & o3, object_ref const & o4, object_ref const & o5, object_ref const & o6, unsigned scalar_sz = 0) { + object * r = alloc_cnstr(tag, 6, scalar_sz); + cnstr_set_obj(r, 0, o1.raw()); inc(o1.raw()); + cnstr_set_obj(r, 1, o2.raw()); inc(o2.raw()); + cnstr_set_obj(r, 2, o3.raw()); inc(o3.raw()); + cnstr_set_obj(r, 3, o4.raw()); inc(o4.raw()); + cnstr_set_obj(r, 4, o5.raw()); inc(o5.raw()); + cnstr_set_obj(r, 5, o6.raw()); inc(o6.raw()); + return object_ref(r); +} + +inline object_ref mk_cnstr(unsigned tag, object_ref const & o1, object_ref const & o2, object_ref const & o3, object_ref const & o4, object_ref const & o5, object_ref const & o6, object_ref const & o7, unsigned scalar_sz = 0) { + object * r = alloc_cnstr(tag, 6, scalar_sz); + cnstr_set_obj(r, 0, o1.raw()); inc(o1.raw()); + cnstr_set_obj(r, 1, o2.raw()); inc(o2.raw()); + cnstr_set_obj(r, 2, o3.raw()); inc(o3.raw()); + cnstr_set_obj(r, 3, o4.raw()); inc(o4.raw()); + cnstr_set_obj(r, 4, o5.raw()); inc(o5.raw()); + cnstr_set_obj(r, 5, o6.raw()); inc(o6.raw()); + cnstr_set_obj(r, 6, o7.raw()); inc(o7.raw()); + return object_ref(r); +} + /* The following definition is a low level hack that relies on the fact that sizeof(object_ref) == sizeof(object *). */ inline object_ref const & cnstr_obj_ref(object * o, unsigned i) { static_assert(sizeof(object_ref) == sizeof(object *), "unexpected object_ref size"); // NOLINT