feat(library/equations_compiler/structural_rec): better support for structural recursion (based on brec_on)
For example, before this commit, structural_rec would not support the
function to_nat defined below.
```
set_option new_elaborator true
inductive foo : bool → Type
| Z : foo ff
| O : foo ff → foo tt
| E : foo tt → foo ff
definition to_nat : ∀ {b}, foo b → nat
| .ff Z := 0
| .tt (O n) := to_nat n + 1
| .ff (E n) := to_nat n + 1
```
This commit is contained in:
parent
b08af16d5f
commit
f1f45cc2b7
1 changed files with 99 additions and 31 deletions
|
|
@ -144,7 +144,10 @@ struct structural_rec_fn {
|
|||
return depends_on_any(e, locals.as_buffer().size(), locals.as_buffer().data());
|
||||
}
|
||||
|
||||
bool check_arg_type(unpack_eqns const & ues, unsigned arg_idx) {
|
||||
/* Return true iff argument arg_idx is a candidate for structural recursion.
|
||||
If the argument type is an indexed family, we store the position of the
|
||||
indices (in the function being defined) in the buffer idx_positions.*/
|
||||
bool check_arg_type(unpack_eqns const & ues, unsigned arg_idx, buffer<unsigned> & idx_positions) {
|
||||
type_context::tmp_locals locals(m_ctx);
|
||||
/* We can only use structural recursion on arg_idx IF
|
||||
1- Type is an inductive datatype with support for the brec_on construction.
|
||||
|
|
@ -174,25 +177,76 @@ struct structural_rec_fn {
|
|||
}
|
||||
unsigned nindices = *inductive::get_num_indices(m_ctx.env(), const_name(I));
|
||||
if (nindices > 0) {
|
||||
trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used "
|
||||
<< "for '" << fn << "' because the inductive type '" << I << "' is an indexed family\n "
|
||||
<< arg_type << "\n";);
|
||||
return false;
|
||||
lean_assert(I_args.size() >= nindices);
|
||||
for (unsigned i = I_args.size() - nindices; i < I_args.size(); i++) {
|
||||
expr const & idx = I_args[i];
|
||||
if (!is_local(idx)) {
|
||||
trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used "
|
||||
<< "for '" << fn << "' because the inductive type '" << I << "' is an indexed family, "
|
||||
<< "and index #" << (i+1) << " is not a variable\n "
|
||||
<< arg_type << "\n";);
|
||||
return false;
|
||||
}
|
||||
/* Index must be an argument of the function being defined */
|
||||
unsigned idx_pos = 0;
|
||||
buffer<expr> const & xs = locals.as_buffer();
|
||||
for (; idx_pos < xs.size(); idx_pos++) {
|
||||
expr const & x = xs[idx_pos];
|
||||
if (mlocal_name(x) == mlocal_name(idx)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (idx_pos == xs.size()) {
|
||||
trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used "
|
||||
<< "for '" << fn << "' because the inductive type '" << I << "' is an indexed family, "
|
||||
<< "and index #" << (i+1) << " is not an argument of the function being defined\n "
|
||||
<< arg_type << "\n";);
|
||||
return false;
|
||||
}
|
||||
/* Index can only depend on other indices in the function being defined. */
|
||||
expr idx_type = m_ctx.infer(idx);
|
||||
for (unsigned j = 0; j < idx_pos; j++) {
|
||||
bool j_is_not_index = std::find(idx_positions.begin(), idx_positions.end(), j) == idx_positions.end();
|
||||
if (j_is_not_index && depends_on(idx_type, xs[j])) {
|
||||
trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used "
|
||||
<< "for '" << fn << "' because the inductive type '" << I << "' is an indexed family, "
|
||||
<< "and index #" << (i+1) << " depends on argument #" << (j+1) << " of '" << fn << "' "
|
||||
<< "which is not an index of the inductive datatype\n "
|
||||
<< arg_type << "\n";);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
idx_positions.push_back(idx_pos);
|
||||
/* Each index can only occur once */
|
||||
for (unsigned j = 0; j < i; j++) {
|
||||
expr const & prev_idx = I_args[j];
|
||||
if (mlocal_name(prev_idx) == mlocal_name(idx)) {
|
||||
trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used "
|
||||
<< "for '" << fn << "' because the inductive type '" << I << "' is an indexed family, "
|
||||
<< "and index #" << (i+1) << " and #" << (j+1) << " must be different variables\n "
|
||||
<< arg_type << "\n";);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (depends_on_locals(arg_type, locals)) {
|
||||
trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used "
|
||||
<< "for '" << fn << "' because type parameter depends on previous arguments\n "
|
||||
<< arg_type << "\n";);
|
||||
return false;
|
||||
for (unsigned i = 0; i < I_args.size() - nindices; i++) {
|
||||
if (depends_on_locals(I_args[i], locals)) {
|
||||
trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used "
|
||||
<< "for '" << fn << "' because type parameter depends on previous arguments\n "
|
||||
<< arg_type << "\n";);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
optional<unsigned> find_rec_arg(unpack_eqns const & ues) {
|
||||
optional<unsigned> find_rec_arg(unpack_eqns const & ues, buffer<unsigned> & idx_positions) {
|
||||
buffer<expr> const & eqns = ues.get_eqns_of(0);
|
||||
unsigned arity = ues.get_arity_of(0);
|
||||
for (unsigned i = 0; i < arity; i++) {
|
||||
if (check_arg_type(ues, i)) {
|
||||
idx_positions.clear();
|
||||
if (check_arg_type(ues, i, idx_positions)) {
|
||||
bool ok = true;
|
||||
for (expr const & eqn : eqns) {
|
||||
if (!check_eq(eqn, i)) {
|
||||
|
|
@ -207,37 +261,43 @@ struct structural_rec_fn {
|
|||
}
|
||||
|
||||
/* Return the type of the new function, and the type of the motive for below/brec_on */
|
||||
pair<expr, expr> mk_new_fn_motive_types(unpack_eqns const & ues, unsigned arg_idx) {
|
||||
pair<expr, expr> mk_new_fn_motive_types(unpack_eqns const & ues, unsigned arg_idx,
|
||||
buffer<unsigned> const & idx_positions) {
|
||||
type_context::tmp_locals locals(m_ctx);
|
||||
expr fn = ues.get_fn(0);
|
||||
expr fn_type = m_ctx.infer(fn);
|
||||
unsigned arity = ues.get_arity_of(0);
|
||||
expr rec_arg;
|
||||
buffer<expr> other_args;
|
||||
buffer<expr> idx_args;
|
||||
for (unsigned i = 0; i < arity; i++) {
|
||||
fn_type = m_ctx.whnf(fn_type);
|
||||
if (!is_pi(fn_type)) throw_ill_formed_eqns();
|
||||
expr arg = locals.push_local_from_binding(fn_type);
|
||||
if (i == arg_idx) {
|
||||
rec_arg = arg;
|
||||
} else if (std::find(idx_positions.begin(), idx_positions.end(), i) != idx_positions.end()) {
|
||||
idx_args.push_back(arg);
|
||||
} else {
|
||||
other_args.push_back(arg);
|
||||
}
|
||||
fn_type = instantiate(binding_body(fn_type), arg);
|
||||
}
|
||||
buffer<expr> I_args;
|
||||
expr I = get_app_args(m_ctx.infer(rec_arg), I_args);
|
||||
buffer<expr> I_params;
|
||||
expr I = get_app_args(m_ctx.infer(rec_arg), I_params);
|
||||
unsigned nindices = idx_positions.size();
|
||||
I_params.shrink(I_params.size() - nindices);
|
||||
expr motive = m_ctx.mk_pi(other_args, fn_type);
|
||||
level u = get_level(m_ctx, motive);
|
||||
motive = m_ctx.mk_lambda(rec_arg, motive);
|
||||
motive = m_ctx.mk_lambda(idx_args, m_ctx.mk_lambda(rec_arg, motive));
|
||||
lean_assert(is_constant(I));
|
||||
buffer<level> below_lvls;
|
||||
below_lvls.push_back(u);
|
||||
for (level const & v : const_levels(I))
|
||||
below_lvls.push_back(v);
|
||||
expr below = mk_app(mk_constant(name(const_name(I), "below"), to_list(below_lvls)), I_args);
|
||||
expr below = mk_app(mk_constant(name(const_name(I), "below"), to_list(below_lvls)), I_params);
|
||||
expr motive_type = binding_domain(m_ctx.relaxed_whnf(m_ctx.infer(below)));
|
||||
below = mk_app(below, motive, rec_arg);
|
||||
below = mk_app(mk_app(mk_app(below, motive), idx_args), rec_arg);
|
||||
locals.push_local("_F", below);
|
||||
return mk_pair(locals.mk_pi(fn_type), motive_type);
|
||||
}
|
||||
|
|
@ -245,15 +305,16 @@ struct structural_rec_fn {
|
|||
struct elim_rec_apps_failed {};
|
||||
|
||||
struct elim_rec_apps_fn : public replace_visitor_with_tc {
|
||||
expr m_fn;
|
||||
unsigned m_arg_idx;
|
||||
expr m_F;
|
||||
expr m_C;
|
||||
expr m_fn;
|
||||
unsigned m_arg_idx;
|
||||
buffer<unsigned> const & m_idx_positions;
|
||||
expr m_F;
|
||||
expr m_C;
|
||||
|
||||
elim_rec_apps_fn(type_context & ctx, expr const & fn,
|
||||
unsigned arg_idx, expr const & F, expr const & C):
|
||||
unsigned arg_idx, buffer<unsigned> const & idx_positions, expr const & F, expr const & C):
|
||||
replace_visitor_with_tc(ctx),
|
||||
m_fn(fn), m_arg_idx(arg_idx), m_F(F), m_C(C) {}
|
||||
m_fn(fn), m_arg_idx(arg_idx), m_idx_positions(idx_positions), m_F(F), m_C(C) {}
|
||||
|
||||
/** \brief Retrieve result for \c a from the below dictionary \c d. \c d is a term made of products,
|
||||
and m_C (the abstract local). */
|
||||
|
|
@ -284,19 +345,24 @@ struct structural_rec_fn {
|
|||
}
|
||||
}
|
||||
|
||||
bool is_index_pos(unsigned idx) const {
|
||||
return std::find(m_idx_positions.begin(), m_idx_positions.end(), idx) != m_idx_positions.end();
|
||||
}
|
||||
|
||||
expr elim(buffer<expr> const & args, tag g) {
|
||||
/* Replace motives with abstract one m_C.
|
||||
We use the abstract motive m_C as "marker". */
|
||||
buffer<expr> below_args;
|
||||
expr const & below_cnst = get_app_args(m_ctx.infer(m_F), below_args);
|
||||
below_args[below_args.size() - 2] = m_C;
|
||||
unsigned nindices = m_idx_positions.size();
|
||||
below_args[below_args.size() - 1 - 1 /* major */ - nindices] = m_C;
|
||||
expr abst_below = mk_app(below_cnst, below_args);
|
||||
expr below_dict = m_ctx.whnf(abst_below);
|
||||
expr rec_arg = m_ctx.whnf(args[m_arg_idx]);
|
||||
if (optional<expr> b = to_below(below_dict, rec_arg, m_F)) {
|
||||
expr r = *b;
|
||||
for (unsigned i = 0; i < args.size(); i++) {
|
||||
if (i != m_arg_idx)
|
||||
if (i != m_arg_idx && !is_index_pos(i))
|
||||
r = mk_app(r, args[i], g);
|
||||
}
|
||||
return r;
|
||||
|
|
@ -329,7 +395,8 @@ struct structural_rec_fn {
|
|||
}
|
||||
};
|
||||
|
||||
void update_eqs(unpack_eqns & ues, expr const & fn, expr const & new_fn, unsigned arg_idx, expr const & motive_type) {
|
||||
void update_eqs(unpack_eqns & ues, expr const & fn, expr const & new_fn,
|
||||
unsigned arg_idx, buffer<unsigned> const & idx_positions, expr const & motive_type) {
|
||||
/* C is a temporary "abstract" motive, we use it to access the "brec_on dictionary".
|
||||
The "brec_on dictionary is an element of type below, and it is the last argument of the new function. */
|
||||
expr C = mk_local(mk_fresh_name(), "_C", motive_type, binder_info());
|
||||
|
|
@ -346,7 +413,7 @@ struct structural_rec_fn {
|
|||
expr F = ue.add_var(binding_name(type), binding_domain(type));
|
||||
new_lhs = mk_app(new_lhs, F);
|
||||
ue.lhs() = new_lhs;
|
||||
ue.rhs() = elim_rec_apps_fn(m_ctx, fn, arg_idx, F, C)(rhs);
|
||||
ue.rhs() = elim_rec_apps_fn(m_ctx, fn, arg_idx, idx_positions, F, C)(rhs);
|
||||
eqn = ue.repack();
|
||||
}
|
||||
}
|
||||
|
|
@ -360,21 +427,22 @@ struct structural_rec_fn {
|
|||
tout() << "\n";);
|
||||
return none_expr();
|
||||
}
|
||||
optional<unsigned> r = find_rec_arg(ues);
|
||||
buffer<unsigned> idx_positions;
|
||||
optional<unsigned> r = find_rec_arg(ues, idx_positions);
|
||||
if (!r) return none_expr();
|
||||
arg_idx = *r;
|
||||
expr fn = ues.get_fn(0);
|
||||
trace_struct(tout() << "using structural recursion on argument #" << (arg_idx+1) <<
|
||||
" for '" << fn << "'\n";);
|
||||
expr new_fn_type, motive_type;
|
||||
std::tie(new_fn_type, motive_type) = mk_new_fn_motive_types(ues, arg_idx);
|
||||
std::tie(new_fn_type, motive_type) = mk_new_fn_motive_types(ues, arg_idx, idx_positions);
|
||||
trace_struct(
|
||||
tout() << "\n";
|
||||
tout() << "new function type: " << new_fn_type << "\n";
|
||||
tout() << "motive type: " << motive_type << "\n";);
|
||||
expr new_fn = ues.update_fn_type(0, new_fn_type);
|
||||
try {
|
||||
update_eqs(ues, fn, new_fn, arg_idx, motive_type);
|
||||
update_eqs(ues, fn, new_fn, arg_idx, idx_positions, motive_type);
|
||||
} catch (elim_rec_apps_failed &) {
|
||||
trace_struct(tout() << "failed to compile equations/match using structural recursion, "
|
||||
<< "when creating new set of equations\n";);
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue