From f1f45cc2b7ae786d5285ddba44d2a4fdaeb31f37 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 29 Aug 2016 10:51:09 -0700 Subject: [PATCH] feat(library/equations_compiler/structural_rec): better support for structural recursion (based on brec_on) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit For example, before this commit, structural_rec would not support the function to_nat defined below. ``` set_option new_elaborator true inductive foo : bool → Type | Z : foo ff | O : foo ff → foo tt | E : foo tt → foo ff definition to_nat : ∀ {b}, foo b → nat | .ff Z := 0 | .tt (O n) := to_nat n + 1 | .ff (E n) := to_nat n + 1 ``` --- .../equations_compiler/structural_rec.cpp | 130 +++++++++++++----- 1 file changed, 99 insertions(+), 31 deletions(-) diff --git a/src/library/equations_compiler/structural_rec.cpp b/src/library/equations_compiler/structural_rec.cpp index 98fa8039ae..d3d78b3fe2 100644 --- a/src/library/equations_compiler/structural_rec.cpp +++ b/src/library/equations_compiler/structural_rec.cpp @@ -144,7 +144,10 @@ struct structural_rec_fn { return depends_on_any(e, locals.as_buffer().size(), locals.as_buffer().data()); } - bool check_arg_type(unpack_eqns const & ues, unsigned arg_idx) { + /* Return true iff argument arg_idx is a candidate for structural recursion. + If the argument type is an indexed family, we store the position of the + indices (in the function being defined) in the buffer idx_positions.*/ + bool check_arg_type(unpack_eqns const & ues, unsigned arg_idx, buffer & idx_positions) { type_context::tmp_locals locals(m_ctx); /* We can only use structural recursion on arg_idx IF 1- Type is an inductive datatype with support for the brec_on construction. @@ -174,25 +177,76 @@ struct structural_rec_fn { } unsigned nindices = *inductive::get_num_indices(m_ctx.env(), const_name(I)); if (nindices > 0) { - trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used " - << "for '" << fn << "' because the inductive type '" << I << "' is an indexed family\n " - << arg_type << "\n";); - return false; + lean_assert(I_args.size() >= nindices); + for (unsigned i = I_args.size() - nindices; i < I_args.size(); i++) { + expr const & idx = I_args[i]; + if (!is_local(idx)) { + trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used " + << "for '" << fn << "' because the inductive type '" << I << "' is an indexed family, " + << "and index #" << (i+1) << " is not a variable\n " + << arg_type << "\n";); + return false; + } + /* Index must be an argument of the function being defined */ + unsigned idx_pos = 0; + buffer const & xs = locals.as_buffer(); + for (; idx_pos < xs.size(); idx_pos++) { + expr const & x = xs[idx_pos]; + if (mlocal_name(x) == mlocal_name(idx)) { + break; + } + } + if (idx_pos == xs.size()) { + trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used " + << "for '" << fn << "' because the inductive type '" << I << "' is an indexed family, " + << "and index #" << (i+1) << " is not an argument of the function being defined\n " + << arg_type << "\n";); + return false; + } + /* Index can only depend on other indices in the function being defined. */ + expr idx_type = m_ctx.infer(idx); + for (unsigned j = 0; j < idx_pos; j++) { + bool j_is_not_index = std::find(idx_positions.begin(), idx_positions.end(), j) == idx_positions.end(); + if (j_is_not_index && depends_on(idx_type, xs[j])) { + trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used " + << "for '" << fn << "' because the inductive type '" << I << "' is an indexed family, " + << "and index #" << (i+1) << " depends on argument #" << (j+1) << " of '" << fn << "' " + << "which is not an index of the inductive datatype\n " + << arg_type << "\n";); + return false; + } + } + idx_positions.push_back(idx_pos); + /* Each index can only occur once */ + for (unsigned j = 0; j < i; j++) { + expr const & prev_idx = I_args[j]; + if (mlocal_name(prev_idx) == mlocal_name(idx)) { + trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used " + << "for '" << fn << "' because the inductive type '" << I << "' is an indexed family, " + << "and index #" << (i+1) << " and #" << (j+1) << " must be different variables\n " + << arg_type << "\n";); + return false; + } + } + } } - if (depends_on_locals(arg_type, locals)) { - trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used " - << "for '" << fn << "' because type parameter depends on previous arguments\n " - << arg_type << "\n";); - return false; + for (unsigned i = 0; i < I_args.size() - nindices; i++) { + if (depends_on_locals(I_args[i], locals)) { + trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used " + << "for '" << fn << "' because type parameter depends on previous arguments\n " + << arg_type << "\n";); + return false; + } } return true; } - optional find_rec_arg(unpack_eqns const & ues) { + optional find_rec_arg(unpack_eqns const & ues, buffer & idx_positions) { buffer const & eqns = ues.get_eqns_of(0); unsigned arity = ues.get_arity_of(0); for (unsigned i = 0; i < arity; i++) { - if (check_arg_type(ues, i)) { + idx_positions.clear(); + if (check_arg_type(ues, i, idx_positions)) { bool ok = true; for (expr const & eqn : eqns) { if (!check_eq(eqn, i)) { @@ -207,37 +261,43 @@ struct structural_rec_fn { } /* Return the type of the new function, and the type of the motive for below/brec_on */ - pair mk_new_fn_motive_types(unpack_eqns const & ues, unsigned arg_idx) { + pair mk_new_fn_motive_types(unpack_eqns const & ues, unsigned arg_idx, + buffer const & idx_positions) { type_context::tmp_locals locals(m_ctx); expr fn = ues.get_fn(0); expr fn_type = m_ctx.infer(fn); unsigned arity = ues.get_arity_of(0); expr rec_arg; buffer other_args; + buffer idx_args; for (unsigned i = 0; i < arity; i++) { fn_type = m_ctx.whnf(fn_type); if (!is_pi(fn_type)) throw_ill_formed_eqns(); expr arg = locals.push_local_from_binding(fn_type); if (i == arg_idx) { rec_arg = arg; + } else if (std::find(idx_positions.begin(), idx_positions.end(), i) != idx_positions.end()) { + idx_args.push_back(arg); } else { other_args.push_back(arg); } fn_type = instantiate(binding_body(fn_type), arg); } - buffer I_args; - expr I = get_app_args(m_ctx.infer(rec_arg), I_args); + buffer I_params; + expr I = get_app_args(m_ctx.infer(rec_arg), I_params); + unsigned nindices = idx_positions.size(); + I_params.shrink(I_params.size() - nindices); expr motive = m_ctx.mk_pi(other_args, fn_type); level u = get_level(m_ctx, motive); - motive = m_ctx.mk_lambda(rec_arg, motive); + motive = m_ctx.mk_lambda(idx_args, m_ctx.mk_lambda(rec_arg, motive)); lean_assert(is_constant(I)); buffer below_lvls; below_lvls.push_back(u); for (level const & v : const_levels(I)) below_lvls.push_back(v); - expr below = mk_app(mk_constant(name(const_name(I), "below"), to_list(below_lvls)), I_args); + expr below = mk_app(mk_constant(name(const_name(I), "below"), to_list(below_lvls)), I_params); expr motive_type = binding_domain(m_ctx.relaxed_whnf(m_ctx.infer(below))); - below = mk_app(below, motive, rec_arg); + below = mk_app(mk_app(mk_app(below, motive), idx_args), rec_arg); locals.push_local("_F", below); return mk_pair(locals.mk_pi(fn_type), motive_type); } @@ -245,15 +305,16 @@ struct structural_rec_fn { struct elim_rec_apps_failed {}; struct elim_rec_apps_fn : public replace_visitor_with_tc { - expr m_fn; - unsigned m_arg_idx; - expr m_F; - expr m_C; + expr m_fn; + unsigned m_arg_idx; + buffer const & m_idx_positions; + expr m_F; + expr m_C; elim_rec_apps_fn(type_context & ctx, expr const & fn, - unsigned arg_idx, expr const & F, expr const & C): + unsigned arg_idx, buffer const & idx_positions, expr const & F, expr const & C): replace_visitor_with_tc(ctx), - m_fn(fn), m_arg_idx(arg_idx), m_F(F), m_C(C) {} + m_fn(fn), m_arg_idx(arg_idx), m_idx_positions(idx_positions), m_F(F), m_C(C) {} /** \brief Retrieve result for \c a from the below dictionary \c d. \c d is a term made of products, and m_C (the abstract local). */ @@ -284,19 +345,24 @@ struct structural_rec_fn { } } + bool is_index_pos(unsigned idx) const { + return std::find(m_idx_positions.begin(), m_idx_positions.end(), idx) != m_idx_positions.end(); + } + expr elim(buffer const & args, tag g) { /* Replace motives with abstract one m_C. We use the abstract motive m_C as "marker". */ buffer below_args; expr const & below_cnst = get_app_args(m_ctx.infer(m_F), below_args); - below_args[below_args.size() - 2] = m_C; + unsigned nindices = m_idx_positions.size(); + below_args[below_args.size() - 1 - 1 /* major */ - nindices] = m_C; expr abst_below = mk_app(below_cnst, below_args); expr below_dict = m_ctx.whnf(abst_below); expr rec_arg = m_ctx.whnf(args[m_arg_idx]); if (optional b = to_below(below_dict, rec_arg, m_F)) { expr r = *b; for (unsigned i = 0; i < args.size(); i++) { - if (i != m_arg_idx) + if (i != m_arg_idx && !is_index_pos(i)) r = mk_app(r, args[i], g); } return r; @@ -329,7 +395,8 @@ struct structural_rec_fn { } }; - void update_eqs(unpack_eqns & ues, expr const & fn, expr const & new_fn, unsigned arg_idx, expr const & motive_type) { + void update_eqs(unpack_eqns & ues, expr const & fn, expr const & new_fn, + unsigned arg_idx, buffer const & idx_positions, expr const & motive_type) { /* C is a temporary "abstract" motive, we use it to access the "brec_on dictionary". The "brec_on dictionary is an element of type below, and it is the last argument of the new function. */ expr C = mk_local(mk_fresh_name(), "_C", motive_type, binder_info()); @@ -346,7 +413,7 @@ struct structural_rec_fn { expr F = ue.add_var(binding_name(type), binding_domain(type)); new_lhs = mk_app(new_lhs, F); ue.lhs() = new_lhs; - ue.rhs() = elim_rec_apps_fn(m_ctx, fn, arg_idx, F, C)(rhs); + ue.rhs() = elim_rec_apps_fn(m_ctx, fn, arg_idx, idx_positions, F, C)(rhs); eqn = ue.repack(); } } @@ -360,21 +427,22 @@ struct structural_rec_fn { tout() << "\n";); return none_expr(); } - optional r = find_rec_arg(ues); + buffer idx_positions; + optional r = find_rec_arg(ues, idx_positions); if (!r) return none_expr(); arg_idx = *r; expr fn = ues.get_fn(0); trace_struct(tout() << "using structural recursion on argument #" << (arg_idx+1) << " for '" << fn << "'\n";); expr new_fn_type, motive_type; - std::tie(new_fn_type, motive_type) = mk_new_fn_motive_types(ues, arg_idx); + std::tie(new_fn_type, motive_type) = mk_new_fn_motive_types(ues, arg_idx, idx_positions); trace_struct( tout() << "\n"; tout() << "new function type: " << new_fn_type << "\n"; tout() << "motive type: " << motive_type << "\n";); expr new_fn = ues.update_fn_type(0, new_fn_type); try { - update_eqs(ues, fn, new_fn, arg_idx, motive_type); + update_eqs(ues, fn, new_fn, arg_idx, idx_positions, motive_type); } catch (elim_rec_apps_failed &) { trace_struct(tout() << "failed to compile equations/match using structural recursion, " << "when creating new set of equations\n";);