feat(library/equations_compiler/structural_rec): better support for structural recursion (based on brec_on)

For example, before this commit, structural_rec would not support the
function to_nat defined below.

```
set_option new_elaborator true

inductive foo : bool → Type
| Z  : foo ff
| O  : foo ff → foo tt
| E  : foo tt → foo ff

definition to_nat : ∀ {b}, foo b → nat
| .ff Z     := 0
| .tt (O n) := to_nat n + 1
| .ff (E n) := to_nat n + 1
```
This commit is contained in:
Leonardo de Moura 2016-08-29 10:51:09 -07:00
parent b08af16d5f
commit f1f45cc2b7

View file

@ -144,7 +144,10 @@ struct structural_rec_fn {
return depends_on_any(e, locals.as_buffer().size(), locals.as_buffer().data());
}
bool check_arg_type(unpack_eqns const & ues, unsigned arg_idx) {
/* Return true iff argument arg_idx is a candidate for structural recursion.
If the argument type is an indexed family, we store the position of the
indices (in the function being defined) in the buffer idx_positions.*/
bool check_arg_type(unpack_eqns const & ues, unsigned arg_idx, buffer<unsigned> & idx_positions) {
type_context::tmp_locals locals(m_ctx);
/* We can only use structural recursion on arg_idx IF
1- Type is an inductive datatype with support for the brec_on construction.
@ -174,25 +177,76 @@ struct structural_rec_fn {
}
unsigned nindices = *inductive::get_num_indices(m_ctx.env(), const_name(I));
if (nindices > 0) {
trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used "
<< "for '" << fn << "' because the inductive type '" << I << "' is an indexed family\n "
<< arg_type << "\n";);
return false;
lean_assert(I_args.size() >= nindices);
for (unsigned i = I_args.size() - nindices; i < I_args.size(); i++) {
expr const & idx = I_args[i];
if (!is_local(idx)) {
trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used "
<< "for '" << fn << "' because the inductive type '" << I << "' is an indexed family, "
<< "and index #" << (i+1) << " is not a variable\n "
<< arg_type << "\n";);
return false;
}
/* Index must be an argument of the function being defined */
unsigned idx_pos = 0;
buffer<expr> const & xs = locals.as_buffer();
for (; idx_pos < xs.size(); idx_pos++) {
expr const & x = xs[idx_pos];
if (mlocal_name(x) == mlocal_name(idx)) {
break;
}
}
if (idx_pos == xs.size()) {
trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used "
<< "for '" << fn << "' because the inductive type '" << I << "' is an indexed family, "
<< "and index #" << (i+1) << " is not an argument of the function being defined\n "
<< arg_type << "\n";);
return false;
}
/* Index can only depend on other indices in the function being defined. */
expr idx_type = m_ctx.infer(idx);
for (unsigned j = 0; j < idx_pos; j++) {
bool j_is_not_index = std::find(idx_positions.begin(), idx_positions.end(), j) == idx_positions.end();
if (j_is_not_index && depends_on(idx_type, xs[j])) {
trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used "
<< "for '" << fn << "' because the inductive type '" << I << "' is an indexed family, "
<< "and index #" << (i+1) << " depends on argument #" << (j+1) << " of '" << fn << "' "
<< "which is not an index of the inductive datatype\n "
<< arg_type << "\n";);
return false;
}
}
idx_positions.push_back(idx_pos);
/* Each index can only occur once */
for (unsigned j = 0; j < i; j++) {
expr const & prev_idx = I_args[j];
if (mlocal_name(prev_idx) == mlocal_name(idx)) {
trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used "
<< "for '" << fn << "' because the inductive type '" << I << "' is an indexed family, "
<< "and index #" << (i+1) << " and #" << (j+1) << " must be different variables\n "
<< arg_type << "\n";);
return false;
}
}
}
}
if (depends_on_locals(arg_type, locals)) {
trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used "
<< "for '" << fn << "' because type parameter depends on previous arguments\n "
<< arg_type << "\n";);
return false;
for (unsigned i = 0; i < I_args.size() - nindices; i++) {
if (depends_on_locals(I_args[i], locals)) {
trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used "
<< "for '" << fn << "' because type parameter depends on previous arguments\n "
<< arg_type << "\n";);
return false;
}
}
return true;
}
optional<unsigned> find_rec_arg(unpack_eqns const & ues) {
optional<unsigned> find_rec_arg(unpack_eqns const & ues, buffer<unsigned> & idx_positions) {
buffer<expr> const & eqns = ues.get_eqns_of(0);
unsigned arity = ues.get_arity_of(0);
for (unsigned i = 0; i < arity; i++) {
if (check_arg_type(ues, i)) {
idx_positions.clear();
if (check_arg_type(ues, i, idx_positions)) {
bool ok = true;
for (expr const & eqn : eqns) {
if (!check_eq(eqn, i)) {
@ -207,37 +261,43 @@ struct structural_rec_fn {
}
/* Return the type of the new function, and the type of the motive for below/brec_on */
pair<expr, expr> mk_new_fn_motive_types(unpack_eqns const & ues, unsigned arg_idx) {
pair<expr, expr> mk_new_fn_motive_types(unpack_eqns const & ues, unsigned arg_idx,
buffer<unsigned> const & idx_positions) {
type_context::tmp_locals locals(m_ctx);
expr fn = ues.get_fn(0);
expr fn_type = m_ctx.infer(fn);
unsigned arity = ues.get_arity_of(0);
expr rec_arg;
buffer<expr> other_args;
buffer<expr> idx_args;
for (unsigned i = 0; i < arity; i++) {
fn_type = m_ctx.whnf(fn_type);
if (!is_pi(fn_type)) throw_ill_formed_eqns();
expr arg = locals.push_local_from_binding(fn_type);
if (i == arg_idx) {
rec_arg = arg;
} else if (std::find(idx_positions.begin(), idx_positions.end(), i) != idx_positions.end()) {
idx_args.push_back(arg);
} else {
other_args.push_back(arg);
}
fn_type = instantiate(binding_body(fn_type), arg);
}
buffer<expr> I_args;
expr I = get_app_args(m_ctx.infer(rec_arg), I_args);
buffer<expr> I_params;
expr I = get_app_args(m_ctx.infer(rec_arg), I_params);
unsigned nindices = idx_positions.size();
I_params.shrink(I_params.size() - nindices);
expr motive = m_ctx.mk_pi(other_args, fn_type);
level u = get_level(m_ctx, motive);
motive = m_ctx.mk_lambda(rec_arg, motive);
motive = m_ctx.mk_lambda(idx_args, m_ctx.mk_lambda(rec_arg, motive));
lean_assert(is_constant(I));
buffer<level> below_lvls;
below_lvls.push_back(u);
for (level const & v : const_levels(I))
below_lvls.push_back(v);
expr below = mk_app(mk_constant(name(const_name(I), "below"), to_list(below_lvls)), I_args);
expr below = mk_app(mk_constant(name(const_name(I), "below"), to_list(below_lvls)), I_params);
expr motive_type = binding_domain(m_ctx.relaxed_whnf(m_ctx.infer(below)));
below = mk_app(below, motive, rec_arg);
below = mk_app(mk_app(mk_app(below, motive), idx_args), rec_arg);
locals.push_local("_F", below);
return mk_pair(locals.mk_pi(fn_type), motive_type);
}
@ -245,15 +305,16 @@ struct structural_rec_fn {
struct elim_rec_apps_failed {};
struct elim_rec_apps_fn : public replace_visitor_with_tc {
expr m_fn;
unsigned m_arg_idx;
expr m_F;
expr m_C;
expr m_fn;
unsigned m_arg_idx;
buffer<unsigned> const & m_idx_positions;
expr m_F;
expr m_C;
elim_rec_apps_fn(type_context & ctx, expr const & fn,
unsigned arg_idx, expr const & F, expr const & C):
unsigned arg_idx, buffer<unsigned> const & idx_positions, expr const & F, expr const & C):
replace_visitor_with_tc(ctx),
m_fn(fn), m_arg_idx(arg_idx), m_F(F), m_C(C) {}
m_fn(fn), m_arg_idx(arg_idx), m_idx_positions(idx_positions), m_F(F), m_C(C) {}
/** \brief Retrieve result for \c a from the below dictionary \c d. \c d is a term made of products,
and m_C (the abstract local). */
@ -284,19 +345,24 @@ struct structural_rec_fn {
}
}
bool is_index_pos(unsigned idx) const {
return std::find(m_idx_positions.begin(), m_idx_positions.end(), idx) != m_idx_positions.end();
}
expr elim(buffer<expr> const & args, tag g) {
/* Replace motives with abstract one m_C.
We use the abstract motive m_C as "marker". */
buffer<expr> below_args;
expr const & below_cnst = get_app_args(m_ctx.infer(m_F), below_args);
below_args[below_args.size() - 2] = m_C;
unsigned nindices = m_idx_positions.size();
below_args[below_args.size() - 1 - 1 /* major */ - nindices] = m_C;
expr abst_below = mk_app(below_cnst, below_args);
expr below_dict = m_ctx.whnf(abst_below);
expr rec_arg = m_ctx.whnf(args[m_arg_idx]);
if (optional<expr> b = to_below(below_dict, rec_arg, m_F)) {
expr r = *b;
for (unsigned i = 0; i < args.size(); i++) {
if (i != m_arg_idx)
if (i != m_arg_idx && !is_index_pos(i))
r = mk_app(r, args[i], g);
}
return r;
@ -329,7 +395,8 @@ struct structural_rec_fn {
}
};
void update_eqs(unpack_eqns & ues, expr const & fn, expr const & new_fn, unsigned arg_idx, expr const & motive_type) {
void update_eqs(unpack_eqns & ues, expr const & fn, expr const & new_fn,
unsigned arg_idx, buffer<unsigned> const & idx_positions, expr const & motive_type) {
/* C is a temporary "abstract" motive, we use it to access the "brec_on dictionary".
The "brec_on dictionary is an element of type below, and it is the last argument of the new function. */
expr C = mk_local(mk_fresh_name(), "_C", motive_type, binder_info());
@ -346,7 +413,7 @@ struct structural_rec_fn {
expr F = ue.add_var(binding_name(type), binding_domain(type));
new_lhs = mk_app(new_lhs, F);
ue.lhs() = new_lhs;
ue.rhs() = elim_rec_apps_fn(m_ctx, fn, arg_idx, F, C)(rhs);
ue.rhs() = elim_rec_apps_fn(m_ctx, fn, arg_idx, idx_positions, F, C)(rhs);
eqn = ue.repack();
}
}
@ -360,21 +427,22 @@ struct structural_rec_fn {
tout() << "\n";);
return none_expr();
}
optional<unsigned> r = find_rec_arg(ues);
buffer<unsigned> idx_positions;
optional<unsigned> r = find_rec_arg(ues, idx_positions);
if (!r) return none_expr();
arg_idx = *r;
expr fn = ues.get_fn(0);
trace_struct(tout() << "using structural recursion on argument #" << (arg_idx+1) <<
" for '" << fn << "'\n";);
expr new_fn_type, motive_type;
std::tie(new_fn_type, motive_type) = mk_new_fn_motive_types(ues, arg_idx);
std::tie(new_fn_type, motive_type) = mk_new_fn_motive_types(ues, arg_idx, idx_positions);
trace_struct(
tout() << "\n";
tout() << "new function type: " << new_fn_type << "\n";
tout() << "motive type: " << motive_type << "\n";);
expr new_fn = ues.update_fn_type(0, new_fn_type);
try {
update_eqs(ues, fn, new_fn, arg_idx, motive_type);
update_eqs(ues, fn, new_fn, arg_idx, idx_positions, motive_type);
} catch (elim_rec_apps_failed &) {
trace_struct(tout() << "failed to compile equations/match using structural recursion, "
<< "when creating new set of equations\n";);