feat(kernel/inductive): generate recursors in the new inductive datatype module

This commit is contained in:
Leonardo de Moura 2018-08-31 17:47:22 -07:00
parent 2fb677f1d0
commit 517923d362
6 changed files with 235 additions and 22 deletions

View file

@ -99,8 +99,8 @@ structure recursor_val extends constant_val :=
(nindices : nat) -- Number of indices
(nmotives : nat) -- Number of motives
(nminor : nat) -- Number of minor premises
(k : bool) -- It supports K-like reduction
(rules : list recursor_rule) -- A reduction for each constructor
(k : bool) -- It supports K-like reduction
(is_meta : bool)
inductive quot_kind

View file

@ -383,6 +383,8 @@ void print_id_info(parser & p, message_builder & out, name const & id, bool show
out << n << "\n";
} else if (d.is_constructor()) {
print_constant(p, out, "(new) constructor", d);
} else if (d.is_recursor()) {
print_constant(p, out, "(new) recursor", d);
}
// print_patterns(p, c);
}

View file

@ -84,12 +84,22 @@ constructor_val::constructor_val(name const & n, level_param_names const & lpara
cnstr_set_scalar<unsigned char>(raw(), sizeof(object*)*3, static_cast<unsigned char>(is_meta));
}
recursor_val::recursor_val(name const & n, level_param_names const & lparams, expr const & type,
names const & all, unsigned nparams, unsigned nindices, unsigned nmotives,
unsigned nminors, recursor_rules const & rules, bool k, bool is_meta):
object_ref(mk_cnstr(0, constant_val(n, lparams, type), all, nat(nparams), nat(nindices), nat(nmotives),
nat(nminors), rules, 2)) {
cnstr_set_scalar<unsigned char>(raw(), sizeof(object*)*7, static_cast<unsigned char>(k));
cnstr_set_scalar<unsigned char>(raw(), sizeof(object*)*7 + 1, static_cast<unsigned char>(is_meta));
}
bool declaration::is_meta() const {
switch (kind()) {
case declaration_kind::Definition: return to_definition_val().is_meta();
case declaration_kind::Axiom: return to_axiom_val().is_meta();
case declaration_kind::Theorem: return false;
case declaration_kind::Inductive: lean_unreachable(); // TODO(Leo):
case declaration_kind::Inductive: return inductive_decl(*this).is_meta();
case declaration_kind::Quot: return false;
case declaration_kind::MutualDefinition: return true;
}
@ -220,6 +230,10 @@ constant_info::constant_info(constructor_val const & v):
object_ref(mk_cnstr(static_cast<unsigned>(constant_info_kind::Constructor), v)) {
}
constant_info::constant_info(recursor_val const & v):
object_ref(mk_cnstr(static_cast<unsigned>(constant_info_kind::Recursor), v)) {
}
static reducibility_hints * g_opaque = nullptr;
reducibility_hints const & constant_info::get_hints() const {
@ -237,7 +251,7 @@ bool constant_info::is_meta() const {
case constant_info_kind::Quot: return false;
case constant_info_kind::Inductive: return to_inductive_val().is_meta();
case constant_info_kind::Constructor: return to_constructor_val().is_meta();
case constant_info_kind::Recursor: return false; // TODO(Leo): to_recursor_val().is_meta();
case constant_info_kind::Recursor: return to_recursor_val().is_meta();
}
lean_unreachable();
}

View file

@ -326,12 +326,15 @@ structure recursor_val extends constant_val :=
(nindices : nat) -- Number of indices
(nmotives : nat) -- Number of motives
(nminors : nat) -- Number of minor premises
(k : bool) -- It supports K-like reduction
(rules : list recursor_rule) -- A reduction for each constructor
(k : bool) -- It supports K-like reduction
(is_meta : bool)
*/
class recursor_val : public object_ref {
public:
recursor_val(name const & n, level_param_names const & lparams, expr const & type,
names const & all, unsigned nparams, unsigned nindices, unsigned nmotives,
unsigned nminors, recursor_rules const & rules, bool k, bool is_meta);
recursor_val(recursor_val const & other):object_ref(other) {}
recursor_val(recursor_val && other):object_ref(other) {}
recursor_val & operator=(recursor_val const & other) { object_ref::operator=(other); return *this; }
@ -343,8 +346,8 @@ public:
nat const & get_nmotives() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 4)); }
nat const & get_nminors() const { return static_cast<nat const &>(cnstr_obj_ref(*this, 5)); }
recursor_rules const & get_rules() const { return static_cast<recursor_rules const &>(cnstr_obj_ref(*this, 6)); }
bool is_k() const;
bool is_meta() const;
bool is_k() const { return cnstr_scalar<unsigned char>(raw(), sizeof(object*)*7) != 0; }
bool is_meta() const { return cnstr_scalar<unsigned char>(raw(), sizeof(object*)*7 + 1) != 0; }
};
enum class quot_kind { Type, Mk, Lift, Ind };
@ -396,6 +399,7 @@ public:
constant_info(quot_val const & v);
constant_info(inductive_val const & v);
constant_info(constructor_val const & v);
constant_info(recursor_val const & v);
constant_info(constant_info const & other):object_ref(other) {}
constant_info(constant_info && other):object_ref(other) {}
@ -413,6 +417,7 @@ public:
bool is_theorem() const { return kind() == constant_info_kind::Theorem; }
bool is_inductive() const { return kind() == constant_info_kind::Inductive; }
bool is_constructor() const { return kind() == constant_info_kind::Constructor; }
bool is_recursor() const { return kind() == constant_info_kind::Recursor; }
name const & get_name() const { return to_constant_val().get_name(); }
level_param_names const & get_univ_params() const { return to_constant_val().get_lparams(); }
@ -427,7 +432,7 @@ public:
theorem_val const & to_theorem_val() const { lean_assert(is_theorem()); return static_cast<theorem_val const &>(to_val()); }
inductive_val const & to_inductive_val() const { lean_assert(is_inductive()); return static_cast<inductive_val const &>(to_val()); }
constructor_val const & to_constructor_val() const { lean_assert(is_constructor()); return static_cast<constructor_val const &>(to_val()); }
// recursor_val const & to_recursor_val() const { lean_assert(is_recursor()); return static_cast<recursor_val const &>(to_val()); }
recursor_val const & to_recursor_val() const { lean_assert(is_recursor()); return static_cast<recursor_val const &>(to_val()); }
};
inline optional<constant_info> none_constant_info() { return optional<constant_info>(); }

View file

@ -49,11 +49,8 @@ class add_inductive_fn {
bool m_K_target;
struct rec_info {
name m_name;
local_ctx m_lctx;
buffer<expr> m_Cs; /* free variables for all motives */
expr m_C; /* free variable for "main" motive */
buffer<expr> m_minor; /* minor premises */
buffer<expr> m_minors; /* minor premises */
buffer<expr> m_indices;
expr m_major; /* major premise */
};
@ -78,6 +75,21 @@ public:
return m_lctx.get_local_decl(m_params[i]).get_type();
}
expr mk_local_decl(local_ctx & lctx, name const & n, expr const & t, binder_info const & bi = binder_info()) {
return lctx.mk_local_decl(m_ngen, n, t, bi);
}
expr mk_local_decl_for(local_ctx & lctx, expr const & t) {
lean_assert(is_pi(t));
return lctx.mk_local_decl(m_ngen, binding_name(t), binding_domain(t), binding_info(t));
}
expr whnf(local_ctx const & lctx, expr const & t) { return tc(lctx).whnf(t); }
expr infer_type(local_ctx const & lctx, expr const & t) { return tc(lctx).infer(t); }
bool is_def_eq(local_ctx const & lctx, expr const & t1, expr const & t2) { return tc(lctx).is_def_eq(t1, t2); }
/**
\brief Check whether the type of each datatype is well typed, and do not contain free variables or meta variables,
all inductive datatypes have the same parameters, the number of parameters match the argument m_nparams,
@ -105,11 +117,11 @@ public:
while (is_pi(type)) {
if (i < m_nparams) {
if (first) {
expr param = m_lctx.mk_local_decl(m_ngen, binding_name(type), binding_domain(type), binding_info(type));
expr param = mk_local_decl_for(m_lctx, type);
m_params.push_back(param);
type = instantiate(binding_body(type), param);
} else {
if (!tc(m_lctx).is_def_eq(binding_domain(type), get_param_type(i)))
if (!is_def_eq(m_lctx, binding_domain(type), get_param_type(i)))
throw kernel_exception(m_env, "parameters of all inductive datatypes must match");
type = instantiate(binding_body(type), m_params[i]);
}
@ -230,24 +242,24 @@ public:
/** \brief Return `some(d_idx)` iff `t` is a recursive argument, `d_idx` is the index of the
recursive inductive datatype. Otherwise, return `none`. */
optional<unsigned> is_rec_argument(local_ctx lctx, expr t) {
t = tc(lctx).whnf(t);
t = whnf(lctx, t);
while (is_pi(t)) {
expr local = lctx.mk_local_decl(m_ngen, binding_name(t), binding_domain(t), binding_info(t));
t = tc(lctx).whnf(instantiate(binding_body(t), local));
expr local = mk_local_decl_for(lctx, t);
t = whnf(lctx, instantiate(binding_body(t), local));
}
return is_valid_ind_app(t);
}
/** \brief Check if \c t contains only positive occurrences of the inductive datatypes being declared. */
void check_positivity(local_ctx lctx, expr t, name const & cnstr_name, int arg_idx) {
t = tc(lctx).whnf(t);
t = whnf(lctx, t);
if (!has_ind_occ(t)) {
// nonrecursive argument
} else if (is_pi(t)) {
if (has_ind_occ(binding_domain(t)))
throw kernel_exception(m_env, sstream() << "arg #" << (arg_idx + 1) << " of '" << cnstr_name << "' "
"has a non positive occurrence of the datatypes being declared");
expr local = lctx.mk_local_decl(m_ngen, binding_name(t), binding_domain(t), binding_info(t));
expr local = mk_local_decl_for(lctx, t);
check_positivity(lctx, instantiate(binding_body(t), local), cnstr_name, arg_idx);
} else if (is_valid_ind_app(t)) {
// recursive argument
@ -272,7 +284,7 @@ public:
local_ctx lctx = m_lctx;
while (is_pi(t)) {
if (i < m_nparams) {
if (!tc(lctx).is_def_eq(binding_domain(t), get_param_type(i)))
if (!is_def_eq(lctx, binding_domain(t), get_param_type(i)))
throw kernel_exception(m_env, sstream() << "arg #" << (i + 1) << " of '" << n << "' "
<< "does not match inductive datatypes parameters'");
t = instantiate(binding_body(t), m_params[i]);
@ -287,7 +299,7 @@ public:
}
if (!m_is_meta)
check_positivity(lctx, binding_domain(t), n, i);
expr local = lctx.mk_local_decl(m_ngen, binding_name(t), binding_domain(t), binding_info(t));
expr local = mk_local_decl_for(lctx, t);
t = instantiate(binding_body(t), local);
}
i++;
@ -346,7 +358,7 @@ public:
buffer<expr> to_check; /* Arguments that we must check if occur in the result type */
local_ctx lctx;
while (is_pi(type)) {
expr fvar = lctx.mk_local_decl(m_ngen, binding_name(type), binding_domain(type), binding_info(type));
expr fvar = mk_local_decl_for(lctx, type);
if (i >= m_nparams) {
expr s = tc(lctx).ensure_type(binding_domain(type));
if (!is_zero(sort_level(s))) {
@ -381,7 +393,6 @@ public:
}
m_elim_level = mk_univ_param(u);
}
// std::cout << ">> elim_level: " << m_elim_level << "\n";
}
void init_K_target() {
@ -408,6 +419,162 @@ public:
}
}
/** \brief Given `t` of the form `I As is` where `I` is one of the inductive datatypes being defined,
As are the global parameters, and is the actual indices provided to it.
Return the index of `I`, and store is in the argument `indices`. */
unsigned get_I_indices(expr const & t, buffer<expr> & indices) {
optional<unsigned> r = is_valid_ind_app(t);
lean_assert(r);
buffer<expr> all_args;
get_app_args(t, all_args);
for (unsigned i = m_nparams; i < all_args.size(); i++)
indices.push_back(all_args[i]);
return *r;
}
/** \brief Populate m_rec_infos. */
void mk_rec_infos() {
unsigned d_idx = 0;
/* First, populate the fields, m_C, m_indices, m_lctx, m_major */
for (inductive_type const & ind_type : m_ind_types) {
rec_info info;
expr t = ind_type.get_type();
unsigned i = 0;
while (is_pi(t)) {
if (i < m_nparams) {
t = instantiate(binding_body(t), m_params[i]);
} else {
expr idx = mk_local_decl_for(m_lctx, t);
info.m_indices.push_back(idx);
t = instantiate(binding_body(t), idx);
}
i++;
}
info.m_major = mk_local_decl(m_lctx, "t",
mk_app(mk_app(m_ind_cnsts[d_idx], m_params), info.m_indices));
expr C_ty = mk_sort(m_elim_level);
C_ty = m_lctx.mk_pi(info.m_major, C_ty);
C_ty = m_lctx.mk_pi(info.m_indices, C_ty);
name C_name("C");
if (m_ind_types.size() > 1)
C_name = name(C_name).append_after(d_idx+1);
info.m_C = mk_local_decl(m_lctx, C_name, C_ty);
m_rec_infos.push_back(info);
d_idx++;
}
/* First, populate the field m_minors */
unsigned minor_idx = 1;
d_idx = 0;
for (inductive_type const & ind_type : m_ind_types) {
for (constructor const & cnstr : ind_type.get_cnstrs()) {
buffer<expr> b_u; // nonrec and rec args;
buffer<expr> u; // rec args
buffer<expr> v; // inductive args
expr t = constructor_type(cnstr);
unsigned i = 0;
while (is_pi(t)) {
if (i < m_nparams) {
t = instantiate(binding_body(t), m_params[i]);
} else {
expr l = mk_local_decl_for(m_lctx, t);
b_u.push_back(l);
if (is_rec_argument(m_lctx, binding_domain(t)))
u.push_back(l);
t = instantiate(binding_body(t), l);
}
i++;
}
buffer<expr> it_indices;
unsigned it_idx = get_I_indices(t, it_indices);
expr C_app = mk_app(m_rec_infos[it_idx].m_C, it_indices);
expr intro_app = mk_app(mk_app(mk_constant(constructor_name(cnstr), m_levels), m_params), b_u);
C_app = mk_app(C_app, intro_app);
/* populate v using u */
for (unsigned i = 0; i < u.size(); i++) {
expr u_i = u[i];
expr u_i_ty = whnf(m_lctx, infer_type(m_lctx, u_i));
buffer<expr> xs;
while (is_pi(u_i_ty)) {
expr x = mk_local_decl_for(m_lctx, u_i_ty);
xs.push_back(x);
u_i_ty = whnf(m_lctx, instantiate(binding_body(u_i_ty), x));
}
buffer<expr> it_indices;
unsigned it_idx = get_I_indices(u_i_ty, it_indices);
expr C_app = mk_app(m_rec_infos[it_idx].m_C, it_indices);
expr u_app = mk_app(u_i, xs);
C_app = mk_app(C_app, u_app);
expr v_i_ty = m_lctx.mk_pi(xs, C_app);
expr v_i = mk_local_decl(m_lctx, name("v").append_after(i), v_i_ty, binder_info());
v.push_back(v_i);
}
expr minor_ty = m_lctx.mk_pi(b_u, m_lctx.mk_pi(v, C_app));
expr minor = mk_local_decl(m_lctx, name("m").append_after(minor_idx), minor_ty);
m_rec_infos[d_idx].m_minors.push_back(minor);
minor_idx++;
}
d_idx++;
}
}
/** \brief Return the levels for the recursor. */
levels get_rec_level_params() {
if (is_param(m_elim_level))
return levels(m_elim_level, m_levels);
else
return m_levels;
}
/** \brief Return the level parameter names for the recursor. */
names get_rec_level_param_names() {
if (is_param(m_elim_level))
return level_param_names(param_id(m_elim_level), m_lparams);
else
return m_lparams;
}
/** \brief Declare recursors. */
void declare_recursors() {
names all = get_all_names();
for (unsigned d_idx = 0; d_idx < m_ind_types.size(); d_idx++) {
rec_info const & info = m_rec_infos[d_idx];
expr C_app = mk_app(mk_app(info.m_C, info.m_indices), info.m_major);
expr rec_ty = m_lctx.mk_pi(info.m_major, C_app);
rec_ty = m_lctx.mk_pi(info.m_indices, rec_ty);
/* Add minor premises */
unsigned nminors = 0;
unsigned i = m_ind_types.size();
while (i > 0) {
--i;
unsigned j = m_rec_infos[i].m_minors.size();
while (j > 0) {
--j;
rec_ty = m_lctx.mk_pi(m_rec_infos[i].m_minors[j], rec_ty);
nminors++;
}
}
/* Add type formers (aka motives) */
unsigned nmotives = 0;
i = m_ind_types.size();
while (i > 0) {
--i;
rec_ty = m_lctx.mk_pi(m_rec_infos[i].m_C, rec_ty);
nmotives++;
}
rec_ty = m_lctx.mk_pi(m_params, rec_ty);
rec_ty = infer_implicit(rec_ty, true /* strict */);
/*
TODO(Leo): gen reduction rule
*/
recursor_rules rules;
name rec_name = mk_rec_name(m_ind_types[d_idx].get_name());
names rec_lparams = get_rec_level_param_names();
m_env.add_core(constant_info(recursor_val(rec_name, rec_lparams, rec_ty, all,
m_nparams, m_nindices[d_idx], nmotives, nminors,
rules, m_K_target, m_is_meta)));
}
}
environment operator()() {
m_env.check_duplicated_univ_params(m_lparams);
check_inductive_types();
@ -416,6 +583,8 @@ public:
declare_constructors();
init_elim_level();
init_K_target();
mk_rec_infos();
declare_recursors();
return m_env;
}
};

View file

@ -95,6 +95,29 @@ inline object_ref mk_cnstr(unsigned tag, object_ref const & o1, object_ref const
return object_ref(r);
}
inline object_ref mk_cnstr(unsigned tag, object_ref const & o1, object_ref const & o2, object_ref const & o3, object_ref const & o4, object_ref const & o5, object_ref const & o6, unsigned scalar_sz = 0) {
object * r = alloc_cnstr(tag, 6, scalar_sz);
cnstr_set_obj(r, 0, o1.raw()); inc(o1.raw());
cnstr_set_obj(r, 1, o2.raw()); inc(o2.raw());
cnstr_set_obj(r, 2, o3.raw()); inc(o3.raw());
cnstr_set_obj(r, 3, o4.raw()); inc(o4.raw());
cnstr_set_obj(r, 4, o5.raw()); inc(o5.raw());
cnstr_set_obj(r, 5, o6.raw()); inc(o6.raw());
return object_ref(r);
}
inline object_ref mk_cnstr(unsigned tag, object_ref const & o1, object_ref const & o2, object_ref const & o3, object_ref const & o4, object_ref const & o5, object_ref const & o6, object_ref const & o7, unsigned scalar_sz = 0) {
object * r = alloc_cnstr(tag, 6, scalar_sz);
cnstr_set_obj(r, 0, o1.raw()); inc(o1.raw());
cnstr_set_obj(r, 1, o2.raw()); inc(o2.raw());
cnstr_set_obj(r, 2, o3.raw()); inc(o3.raw());
cnstr_set_obj(r, 3, o4.raw()); inc(o4.raw());
cnstr_set_obj(r, 4, o5.raw()); inc(o5.raw());
cnstr_set_obj(r, 5, o6.raw()); inc(o6.raw());
cnstr_set_obj(r, 6, o7.raw()); inc(o7.raw());
return object_ref(r);
}
/* The following definition is a low level hack that relies on the fact that sizeof(object_ref) == sizeof(object *). */
inline object_ref const & cnstr_obj_ref(object * o, unsigned i) {
static_assert(sizeof(object_ref) == sizeof(object *), "unexpected object_ref size"); // NOLINT