From 8b67480ceefe8fa3c847ff5cd10b68ceffc6f66d Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 15 Aug 2016 15:34:51 -0700 Subject: [PATCH] feat(library/equations_compiler): add step for handling structural recursion --- src/library/equations_compiler/CMakeLists.txt | 1 + src/library/equations_compiler/compiler.cpp | 21 +- src/library/equations_compiler/compiler.h | 2 + src/library/equations_compiler/equations.h | 14 +- .../equations_compiler/init_module.cpp | 6 + .../equations_compiler/pack_domain.cpp | 33 ++- .../equations_compiler/structural_rec.cpp | 269 ++++++++++++++++++ .../equations_compiler/structural_rec.h | 25 ++ src/library/equations_compiler/util.cpp | 103 ++++++- src/library/equations_compiler/util.h | 68 ++++- src/library/type_context.h | 7 + 11 files changed, 501 insertions(+), 48 deletions(-) create mode 100644 src/library/equations_compiler/structural_rec.cpp create mode 100644 src/library/equations_compiler/structural_rec.h diff --git a/src/library/equations_compiler/CMakeLists.txt b/src/library/equations_compiler/CMakeLists.txt index 39557a55e3..f853bb5079 100644 --- a/src/library/equations_compiler/CMakeLists.txt +++ b/src/library/equations_compiler/CMakeLists.txt @@ -1,5 +1,6 @@ add_library(equations_compiler OBJECT equations.cpp util.cpp pack_domain.cpp + structural_rec.cpp compiler.cpp init_module.cpp #LEGACY old_compiler.cpp old_goal.cpp old_inversion.cpp) diff --git a/src/library/equations_compiler/compiler.cpp b/src/library/equations_compiler/compiler.cpp index 98fa955b30..a1460af9c2 100644 --- a/src/library/equations_compiler/compiler.cpp +++ b/src/library/equations_compiler/compiler.cpp @@ -8,14 +8,27 @@ Author: Leonardo de Moura #include "library/equations_compiler/compiler.h" #include "library/equations_compiler/util.h" #include "library/equations_compiler/pack_domain.h" +#include "library/equations_compiler/structural_rec.h" namespace lean { +#define trace_compiler(Code) lean_trace("eqn_compiler", scope_trace_env _scope1(ctx->env(), ctx); Code) + expr compile_equations(environment const & env, options const & opts, metavar_context & mctx, local_context const & lctx, expr const & eqns) { - aux_type_context ctx(env, opts, mctx, lctx); - tout() << eqns << "\n"; - expr eqns1 = pack_domain(ctx.get(), eqns); - tout() << eqns1 << "\n"; + aux_type_context ctx(env, opts, mctx, lctx, transparency_mode::Semireducible); + trace_compiler(tout() << "compiling\n" << eqns << "\n";); + trace_compiler(tout() << "recursive: " << is_recursive_eqns(ctx, eqns) << "\n";); + + // expr eqns1 = pack_domain(ctx.get(), eqns); + // tout() << eqns1 << "\n"; + unsigned arg_idx; + optional eqns1 = try_structural_rec(ctx.get(), eqns, arg_idx); lean_unreachable(); } + +void initialize_compiler() { + register_trace_class("eqn_compiler"); +} +void finalize_compiler() { +} } diff --git a/src/library/equations_compiler/compiler.h b/src/library/equations_compiler/compiler.h index ac01d9fb0f..c80bf555b3 100644 --- a/src/library/equations_compiler/compiler.h +++ b/src/library/equations_compiler/compiler.h @@ -10,4 +10,6 @@ Author: Leonardo de Moura namespace lean { expr compile_equations(environment const & env, options const & opts, metavar_context & mctx, local_context const & lctx, expr const & eqns); +void initialize_compiler(); +void finalize_compiler(); } diff --git a/src/library/equations_compiler/equations.h b/src/library/equations_compiler/equations.h index e589fafea0..76f2800637 100644 --- a/src/library/equations_compiler/equations.h +++ b/src/library/equations_compiler/equations.h @@ -49,16 +49,9 @@ void to_equations(expr const & e, buffer & eqns); expr const & equations_wf_proof(expr const & e); expr const & equations_wf_rel(expr const & e); -/* TODO(Leo): delete the following versions */ -expr mk_equations(unsigned num_fns, unsigned num_eqs, expr const * eqs); -expr mk_equations(unsigned num_fns, unsigned num_eqs, expr const * eqs, expr const & R, expr const & Hwf); -/* End of delete ------------- */ - /** \brief Return true if \c e is an auxiliary macro used to store the result of mutually recursive declarations. For example, if a set of recursive equations is defining \c n mutually recursive functions, we wrap - the \c n resulting functions (and their types) with an \c equations_result macro. - - TODO(Leo): delete this after we implement the new equations compiler */ + the \c n resulting functions (and their types) with an \c equations_result macro. */ bool is_equations_result(expr const & e); expr mk_equations_result(unsigned n, expr const * rs); unsigned get_equations_result_size(expr const & e); @@ -66,4 +59,9 @@ expr const & get_equations_result(expr const & e, unsigned i); void initialize_equations(); void finalize_equations(); + +/* TODO(Leo): delete the following versions */ +expr mk_equations(unsigned num_fns, unsigned num_eqs, expr const * eqs); +expr mk_equations(unsigned num_fns, unsigned num_eqs, expr const * eqs, expr const & R, expr const & Hwf); +/* End of delete ------------- */ } diff --git a/src/library/equations_compiler/init_module.cpp b/src/library/equations_compiler/init_module.cpp index 9146240e1f..b67ab08cfa 100644 --- a/src/library/equations_compiler/init_module.cpp +++ b/src/library/equations_compiler/init_module.cpp @@ -5,13 +5,19 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include "library/equations_compiler/equations.h" +#include "library/equations_compiler/structural_rec.h" +#include "library/equations_compiler/compiler.h" namespace lean{ void initialize_equations_compiler_module() { initialize_equations(); + initialize_structural_rec(); + initialize_compiler(); } void finalize_equations_compiler_module() { + finalize_compiler(); + finalize_structural_rec(); finalize_equations(); } } diff --git a/src/library/equations_compiler/pack_domain.cpp b/src/library/equations_compiler/pack_domain.cpp index f8fe91dab6..054d0e0ea7 100644 --- a/src/library/equations_compiler/pack_domain.cpp +++ b/src/library/equations_compiler/pack_domain.cpp @@ -54,8 +54,8 @@ struct sigma_packer_fn { } class update_apps_fn : public replace_visitor_with_tc { - buffer const & m_old_fns; - equations_editor const & m_editor; + buffer const & m_old_fns; + unpack_eqns const & m_ues; optional get_fn_idx(expr const & fn) { if (!is_local(fn)) return optional(); @@ -89,9 +89,9 @@ struct sigma_packer_fn { expr const & fn = get_app_args(e, args); auto fnidx = get_fn_idx(fn); if (!fnidx) return replace_visitor_with_tc::visit_app(e); - expr new_fn = m_editor.get_fn(*fnidx); + expr new_fn = m_ues.get_fn(*fnidx); if (fn == new_fn) return replace_visitor_with_tc::visit_app(e); - unsigned arity = m_editor.get_arity(*fnidx); + unsigned arity = m_ues.get_arity_of(*fnidx); if (args.size() < arity) { expr new_e = m_ctx.eta_expand(e); if (!is_lambda(new_e)) throw_ill_formed_eqns(); @@ -105,33 +105,32 @@ struct sigma_packer_fn { } public: - update_apps_fn(type_context & ctx, buffer const & old_fns, equations_editor const & editor): - replace_visitor_with_tc(ctx), m_old_fns(old_fns), m_editor(editor) {} + update_apps_fn(type_context & ctx, buffer const & old_fns, unpack_eqns const & ues): + replace_visitor_with_tc(ctx), m_old_fns(old_fns), m_ues(ues) {} }; expr operator()(expr const & e) { - equations_editor editor; - editor.unpack(e); + unpack_eqns ues(m_ctx, e); buffer old_fns; bool modified = false; - for (unsigned fidx = 0; fidx < editor.get_num_fns(); fidx++) { - expr & fn = editor.get_fn(fidx); + for (unsigned fidx = 0; fidx < ues.get_num_fns(); fidx++) { + expr const & fn = ues.get_fn(fidx); old_fns.push_back(fn); - unsigned arity = editor.get_arity(fidx); + unsigned arity = ues.get_arity_of(fidx); if (arity > 1) { - expr new_type = pack_as_unary(mlocal_type(fn), arity); - fn = update_mlocal(fn, new_type); + expr new_type = pack_as_unary(m_ctx.infer(fn), arity); + ues.update_fn_type(fidx, new_type); modified = true; } } if (!modified) return e; - update_apps_fn updt(m_ctx, old_fns, editor); - for (unsigned fidx = 0; fidx < editor.get_num_fns(); fidx++) { - buffer & eqs = editor.get_eqs_of(fidx); + update_apps_fn updt(m_ctx, old_fns, ues); + for (unsigned fidx = 0; fidx < ues.get_num_fns(); fidx++) { + buffer & eqs = ues.get_eqns_of(fidx); for (expr & eq : eqs) eq = updt(eq); } - return editor.repack(); + return ues.repack(); } }; diff --git a/src/library/equations_compiler/structural_rec.cpp b/src/library/equations_compiler/structural_rec.cpp new file mode 100644 index 0000000000..a8bca6a0d3 --- /dev/null +++ b/src/library/equations_compiler/structural_rec.cpp @@ -0,0 +1,269 @@ +/* +Copyright (c) 2016 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include "kernel/instantiate.h" +#include "library/trace.h" +#include "library/locals.h" +#include "library/app_builder.h" +#include "library/equations_compiler/util.h" +#include "library/equations_compiler/structural_rec.h" + +namespace lean { +#define trace_struct(Code) lean_trace(name({"eqn_compiler", "structural_rec"}), scope_trace_env _scope1(m_ctx.env(), m_ctx); Code) + +struct structural_rec_fn { + type_context & m_ctx; + structural_rec_fn(type_context & ctx):m_ctx(ctx) {} + + /** \brief Auxiliary object for checking whether recursive application are + structurally smaller or not */ + struct check_rhs_fn { + type_context & m_ctx; + expr m_lhs; + expr m_fn; + expr m_pattern; + unsigned m_arg_idx; + + check_rhs_fn(type_context & ctx, expr const & lhs, expr const & fn, expr const & pattern, unsigned arg_idx): + m_ctx(ctx), m_lhs(lhs), m_fn(fn), m_pattern(pattern), m_arg_idx(arg_idx) {} + + bool is_constructor(expr const & e) const { + return static_cast(eqns_env_interface(m_ctx).is_constructor(e)); + } + + /** \brief Return true iff \c s is structurally smaller than \c t OR equal to \c t */ + bool is_le(expr const & s, expr const & t) { + return m_ctx.is_def_eq(s, t) || is_lt(s, t); + } + + /** Return true iff \c s is structurally smaller than \c t */ + bool is_lt(expr s, expr const & t) { + s = m_ctx.whnf(s); + if (is_app(s)) { + expr const & s_fn = get_app_fn(s); + if (!is_constructor(s_fn)) + return is_lt(s_fn, t); // f < t ==> s := f a_1 ... a_n < t + } + buffer t_args; + expr const & t_fn = get_app_args(t, t_args); + if (!is_constructor(t_fn)) + return false; + return std::any_of(t_args.begin(), t_args.end(), + [&](expr const & t_arg) { return is_le(s, t_arg); }); + } + + /** \brief Return true iff all recursive applications in \c e are structurally smaller than \c m_pattern. */ + bool check_rhs(expr const & e) { + switch (e.kind()) { + case expr_kind::Var: case expr_kind::Meta: + case expr_kind::Local: case expr_kind::Constant: + case expr_kind::Sort: + return true; + case expr_kind::Macro: + for (unsigned i = 0; i < macro_num_args(e); i++) + if (!check_rhs(macro_arg(e, i))) + return false; + return true; + case expr_kind::App: { + buffer args; + expr const & fn = get_app_args(e, args); + if (!check_rhs(fn)) + return false; + for (unsigned i = 0; i < args.size(); i++) + if (!check_rhs(args[i])) + return false; + if (is_local(fn) && mlocal_name(fn) == mlocal_name(m_fn)) { + /* recusive application */ + if (m_arg_idx < args.size()) { + expr const & arg = args[m_arg_idx]; + /* arg must be structurally smaller than m_pattern */ + if (!is_lt(arg, m_pattern)) { + trace_struct(tout() << "structural recursion on argument #" << (m_arg_idx+1) << " was not used " + << "for '" << m_fn << "'\nargument #" << (m_arg_idx+1) + << " in the application\n " + << e << "\nis not structurally smaller than the one occurring in " + << "the equation left-hand-side\n " + << m_lhs << "\n";); + return false; + } + } else { + /* function is not fully applied */ + trace_struct(tout() << "structural recursion on argument #" << (m_arg_idx+1) << " was not used " + << "for '" << m_fn << "' because of the partial application\n " + << e << "\n";); + return false; + } + } + return true; + } + case expr_kind::Let: + if (!check_rhs(let_value(e))) { + return false; + } else { + type_context::tmp_locals locals(m_ctx); + return check_rhs(instantiate(let_body(e), locals.push_local_from_let(e))); + } + case expr_kind::Lambda: + case expr_kind::Pi: + if (!check_rhs(binding_domain(e))) { + return false; + } else { + type_context::tmp_locals locals(m_ctx); + return check_rhs(instantiate(binding_body(e), locals.push_local_from_binding(e))); + } + } + lean_unreachable(); + } + + bool operator()(expr const & e) { + return check_rhs(e); + } + }; + + bool check_rhs(expr const & lhs, expr const & fn, expr pattern, unsigned arg_idx, expr const & rhs) { + pattern = m_ctx.whnf(pattern); + return check_rhs_fn(m_ctx, lhs, fn, pattern, arg_idx)(rhs); + } + + bool check_eq(expr const & eqn, unsigned arg_idx) { + unpack_eqn ue(m_ctx, eqn); + buffer args; + expr const & fn = get_app_args(ue.lhs(), args); + return check_rhs(ue.lhs(), fn, args[arg_idx], arg_idx, ue.rhs()); + } + + static bool depends_on_locals(expr const & e, type_context::tmp_locals const & locals) { + return depends_on_any(e, locals.as_buffer().size(), locals.as_buffer().data()); + } + + bool check_arg_type(unpack_eqns const & ues, unsigned arg_idx) { + type_context::tmp_locals locals(m_ctx); + /* We can only use structural recursion on arg_idx IF + 1- Type is an inductive datatype with support for the brec_on construction. + 2- Type parameters do not depend on other arguments of the function being defined. */ + expr fn = ues.get_fn(0); + expr fn_type = m_ctx.infer(fn); + for (unsigned i = 0; i < arg_idx; i++) { + fn_type = m_ctx.whnf(fn_type); + if (!is_pi(fn_type)) throw_ill_formed_eqns(); + fn_type = instantiate(binding_body(fn_type), locals.push_local_from_binding(fn_type)); + } + if (!is_pi(fn_type)) throw_ill_formed_eqns(); + expr arg_type = binding_domain(fn_type); + buffer I_args; + expr I = get_app_args(arg_type, I_args); + if (!eqns_env_interface(m_ctx).is_inductive(I)) { + trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used " + << "for '" << fn << "' because type is not inductive\n " + << arg_type << "\n";); + return false; + } + if (!m_ctx.env().find(name(const_name(I), "brec_on"))) { + trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used " + << "for '" << fn << "' because the inductive type '" << I << "' does have brec_on recursor\n " + << arg_type << "\n";); + return false; + } + unsigned nindices = eqns_env_interface(m_ctx).get_inductive_num_indices(const_name(I)); + if (nindices > 0) { + trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used " + << "for '" << fn << "' because the inductive type '" << I << "' is an indexed family\n " + << arg_type << "\n";); + return false; + } + if (depends_on_locals(arg_type, locals)) { + trace_struct(tout() << "structural recursion on argument #" << (arg_idx+1) << " was not used " + << "for '" << fn << "' because type parameter depends on previous arguments\n " + << arg_type << "\n";); + return false; + } + return true; + } + + optional find_rec_arg(unpack_eqns const & ues) { + buffer const & eqns = ues.get_eqns_of(0); + unsigned arity = ues.get_arity_of(0); + for (unsigned i = 0; i < arity; i++) { + if (check_arg_type(ues, i)) { + bool ok = true; + for (expr const & eqn : eqns) { + if (!check_eq(eqn, i)) { + ok = false; + break; + } + } + if (ok) return optional(i); + } + } + return optional(); + } + + expr mk_new_fn_type(unpack_eqns const & ues, unsigned arg_idx) { + type_context::tmp_locals locals(m_ctx); + expr fn = ues.get_fn(0); + expr fn_type = m_ctx.infer(fn); + unsigned arity = ues.get_arity_of(0); + expr rec_arg; + buffer other_args; + for (unsigned i = 0; i < arity; i++) { + fn_type = m_ctx.whnf(fn_type); + if (!is_pi(fn_type)) throw_ill_formed_eqns(); + expr arg = locals.push_local_from_binding(fn_type); + if (i == arg_idx) { + rec_arg = arg; + } else { + other_args.push_back(arg); + } + fn_type = instantiate(binding_body(fn_type), arg); + } + expr motive = m_ctx.mk_pi(other_args, fn_type); + level u = get_level(m_ctx, motive); + motive = m_ctx.mk_lambda(rec_arg, motive); + buffer I_args; + expr I = get_app_args(m_ctx.infer(rec_arg), I_args); + lean_assert(is_constant(I)); + buffer below_lvls; + below_lvls.push_back(u); + for (level const & v : const_levels(I)) + below_lvls.push_back(v); + expr below = mk_app(mk_constant(name(const_name(I), "below"), to_list(below_lvls)), motive, rec_arg); + locals.push_local("_F", below); + return locals.mk_pi(fn_type); + } + + optional operator()(expr const & e, unsigned & arg_idx) { + unpack_eqns ues(m_ctx, e); + if (ues.get_num_fns() != 1) { + trace_struct(tout() << "structural recursion is not supported for mutually recursive functions:"; + for (unsigned i = 0; i < ues.get_num_fns(); i++) + tout() << " " << ues.get_fn(i); + tout() << "\n";); + return none_expr(); + } + optional r = find_rec_arg(ues); + if (!r) return none_expr(); + arg_idx = *r; + trace_struct(tout() << "using structural recursion on argument #" << (arg_idx+1) << + " for '" << ues.get_fn(0) << "'\n";); + expr new_fn_type = mk_new_fn_type(ues, arg_idx); + + trace_struct(tout() << "new function type: " << new_fn_type << "\n";); + + // TODO(Leo) + + return some_expr(ues.repack()); + } +}; + +optional try_structural_rec(type_context & ctx, expr const & e, unsigned & arg_idx) { + return structural_rec_fn(ctx)(e, arg_idx); +} + +void initialize_structural_rec() { + register_trace_class({"eqn_compiler", "structural_rec"}); +} +void finalize_structural_rec() {} +} diff --git a/src/library/equations_compiler/structural_rec.h b/src/library/equations_compiler/structural_rec.h new file mode 100644 index 0000000000..e5b19a62fd --- /dev/null +++ b/src/library/equations_compiler/structural_rec.h @@ -0,0 +1,25 @@ +/* +Copyright (c) 2016 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include "library/type_context.h" +namespace lean { +/** \brief Try to eliminate "recursive calls" in the equations \c e by using brec_on's below. + If successful, a new set of (non-recursive) equations is produced. The new equations + have a new argument of type `I.below` and all "recursive calls" are replaced with it. + + The procedure fails when: + 1- \c e is defining more than one function + 2- None of the arguments is a primitive inductive datatype with support for brec_on + construction, where every recursive call is structurally smaller. + + \remark arg_idx is an ouput parameter. When successful, it contains the argument that + we should apply brec_on too. */ +optional try_structural_rec(type_context & ctx, expr const & e, unsigned & arg_idx); + +void initialize_structural_rec(); +void finalize_structural_rec(); +} diff --git a/src/library/equations_compiler/util.cpp b/src/library/equations_compiler/util.cpp index 9a188d201c..49cfde8bde 100644 --- a/src/library/equations_compiler/util.cpp +++ b/src/library/equations_compiler/util.cpp @@ -6,6 +6,8 @@ Author: Leonardo de Moura */ #include "kernel/instantiate.h" #include "kernel/abstract.h" +#include "kernel/find_fn.h" +#include "kernel/inductive/inductive.h" #include "library/equations_compiler/equations.h" #include "library/equations_compiler/util.h" @@ -36,10 +38,8 @@ static expr consume_fn_prefix(expr eq, buffer const & fns) { return instantiate_rev(eq, fns); } -void equations_editor::unpack(expr const & e) { - m_fns.clear(); - m_arity.clear(); - m_eqs.clear(); +unpack_eqns::unpack_eqns(type_context & ctx, expr const & e): + m_locals(ctx) { lean_assert(is_equations(e)); m_src = e; buffer eqs; @@ -49,8 +49,9 @@ void equations_editor::unpack(expr const & e) { lean_assert(eqs.size() > 0); expr eq = eqs[0]; for (unsigned i = 0; i < num_fns; i++) { - lean_assert(is_lambda(eq)); - m_fns.push_back(mk_local(binding_name(eq), binding_domain(eq))); + if (!is_lambda(eq)) throw_ill_formed_eqns(); + if (!closed(binding_domain(eq))) throw_ill_formed_eqns(); + m_fns.push_back(m_locals.push_local(binding_name(eq), binding_domain(eq))); eq = binding_body(eq); } /* Extract equations */ @@ -95,13 +96,99 @@ void equations_editor::unpack(expr const & e) { lean_assert(m_eqs.size() == m_fns.size()); } -expr equations_editor::repack() { +expr unpack_eqns::update_fn_type(unsigned fidx, expr const & type) { + expr new_fn = m_locals.push_local(local_pp_name(m_fns[fidx]), type); + m_fns[fidx] = new_fn; + return new_fn; +} + +expr unpack_eqns::repack() { buffer new_eqs; for (buffer const & fn_eqs : m_eqs) { for (expr const & eq : fn_eqs) { - new_eqs.push_back(Fun(m_fns, eq)); + new_eqs.push_back(m_locals.ctx().mk_lambda(m_fns, eq)); } } return update_equations(m_src, new_eqs); } + +unpack_eqn::unpack_eqn(type_context & ctx, expr const & eqn): + m_src(eqn), m_locals(ctx) { + expr it = eqn; + while (is_lambda(it)) { + expr d = instantiate_rev(binding_domain(it), m_locals.as_buffer().size(), m_locals.as_buffer().data()); + m_vars.push_back(m_locals.push_local(binding_name(it), d, binding_info(it))); + it = binding_body(it); + } + it = instantiate_rev(it, m_locals.as_buffer().size(), m_locals.as_buffer().data()); + if (!is_equation(it)) throw_ill_formed_eqns(); + m_nested_src = it; + m_lhs = equation_lhs(it); + m_rhs = equation_rhs(it); +} + +expr unpack_eqn::add_var(name const & n, expr const & type) { + m_modified_vars = true; + m_vars.push_back(m_locals.push_local(n, type)); + return m_vars.back(); +} + +expr unpack_eqn::repack() { + if (!m_modified_vars && + equation_lhs(m_nested_src) == m_lhs && + equation_rhs(m_nested_src) == m_rhs) return m_src; + expr new_eq = copy_tag(m_nested_src, mk_equation(m_lhs, m_rhs)); + return copy_tag(m_src, m_locals.ctx().mk_lambda(m_vars, new_eq)); +} + +bool eqns_env_interface::is_inductive(name const & n) const { + return static_cast(inductive::is_inductive_decl(m_env, n)); +} + +bool eqns_env_interface::is_inductive(expr const & e) const { + if (!is_constant(e)) return false; + return is_inductive(const_name(e)); +} + +optional eqns_env_interface::is_constructor(expr const & e) const { + if (!is_constant(e)) return optional(); + return inductive::is_intro_rule(m_env, const_name(e)); +} + +unsigned eqns_env_interface::get_inductive_num_params(name const & n) const { + lean_assert(is_inductive(n)); + return *inductive::get_num_params(m_env, n); +} + +unsigned eqns_env_interface::get_inductive_num_indices(name const & n) const { + lean_assert(is_inductive(n)); + return *inductive::get_num_indices(m_env, n); +} + +bool is_recursive_eqns(type_context & ctx, expr const & e) { + unpack_eqns ues(ctx, e); + for (unsigned fidx = 0; fidx < ues.get_num_fns(); fidx++) { + buffer const & eqns = ues.get_eqns_of(fidx); + for (expr const & eqn : eqns) { + expr it = eqn; + while (is_lambda(it)) { + it = binding_body(it); + } + if (!is_equation(it)) throw_ill_formed_eqns(); + expr const & rhs = equation_rhs(it); + if (find(rhs, [&](expr const & e, unsigned) { + if (is_local(e)) { + for (unsigned fidx = 0; fidx < ues.get_num_fns(); fidx++) { + if (mlocal_name(e) == mlocal_name(ues.get_fn(fidx))) + return true; + } + } + return false; + })) { + return true; + } + } + } + return false; +} } diff --git a/src/library/equations_compiler/util.h b/src/library/equations_compiler/util.h index d0af692b38..9debbb777f 100644 --- a/src/library/equations_compiler/util.h +++ b/src/library/equations_compiler/util.h @@ -7,6 +7,8 @@ Author: Leonardo de Moura #pragma once #include "library/type_context.h" namespace lean { +[[ noreturn ]] void throw_ill_formed_eqns(); + /** \brief Helper class for modifying/updating an equations-expression. \remark The equations macro is awkward to use since it is a leftover @@ -16,32 +18,76 @@ namespace lean { TODO(Leo): as soon as we remove the legacy code from Lean2, this class will be much simpler. */ -class equations_editor { - expr m_src; - buffer m_fns; +class unpack_eqns { + type_context::tmp_locals m_locals; + expr m_src; + buffer m_fns; /* m_arity[i] contains the number of arguments for each equation lhs for m_fns[i]. \remark m_arity.size() == m_fns.size(). \remark The information stored in this field is ignore by repack. */ - buffer m_arity; + buffer m_arity; /* m_eqns[i] are the equations for m_fns[i]. \remark m_eqs.size() == m_fns.size(). */ - buffer> m_eqs; + buffer> m_eqs; public: /** \brief Extract the data stored in the equations-expression \c e. \pre is_equations(e) */ - void unpack(expr const & e); + unpack_eqns(type_context & ctx, expr const & e); /** \brief Re-build an equations-expression using the information stored at m_fns and m_eqs. */ expr repack(); + /** Update the type of the function with the given idx. + \remark The equations are not updated. They still reference the old function. */ + expr update_fn_type(unsigned fidx, expr const & type); + unsigned get_num_fns() const { return m_fns.size(); } - expr & get_fn(unsigned fidx) { return m_fns[fidx]; } expr const & get_fn(unsigned fidx) const { return m_fns[fidx]; } - buffer & get_eqs_of(unsigned fidx) { return m_eqs[fidx]; } - buffer const & get_eqs_of(unsigned fidx) const { return m_eqs[fidx]; } - unsigned get_arity(unsigned fidx) const { return m_arity[fidx]; } + buffer & get_eqns_of(unsigned fidx) { return m_eqs[fidx]; } + buffer const & get_eqns_of(unsigned fidx) const { return m_eqs[fidx]; } + unsigned get_arity_of(unsigned fidx) const { return m_arity[fidx]; } }; -void throw_ill_formed_eqns(); +/** \brief Helper class for unpacking a single equation nested in a equations expression. */ +class unpack_eqn { + expr m_src; + type_context::tmp_locals m_locals; + bool m_modified_vars{false}; + buffer m_vars; + expr m_nested_src; + expr m_lhs; + expr m_rhs; +public: + unpack_eqn(type_context & ctx, expr const & eqn); + expr add_var(name const & n, expr const & type); + buffer const & get_vars() { return m_vars; } + expr & lhs() { return m_lhs; } + expr & rhs() { return m_rhs; } + expr repack(); +}; + +/** \brief Interface object for providing extra functionality + required by the equation compiler from the environment. + + For example, it abstracts the inductive datatype API. + So, if we add new forms of inductive datatype, we need + to change this class. */ +class eqns_env_interface { + environment m_env; +public: + eqns_env_interface(environment const & env):m_env(env) {} + eqns_env_interface(type_context const & ctx):m_env(ctx.env()) {} + + bool is_inductive(name const & n) const; + bool is_inductive(expr const & e) const; + optional is_constructor(expr const & e) const; + unsigned get_inductive_num_params(name const & n) const; + unsigned get_inductive_num_indices(name const & n) const; +}; + +/** \brief Return true iff \c e is recursive. That is, some equation + in the rhs has a reference to a function being defined by the + equations. */ +bool is_recursive_eqns(type_context & ctx, expr const & e); } diff --git a/src/library/type_context.h b/src/library/type_context.h index 793515085e..80c0e4e8de 100644 --- a/src/library/type_context.h +++ b/src/library/type_context.h @@ -471,6 +471,8 @@ public: tmp_locals(type_context & ctx):m_ctx(ctx) {} ~tmp_locals(); + type_context & ctx() { return m_ctx; } + expr push_local(name const & pp_name, expr const & type, binder_info const & bi = binder_info()) { expr r = m_ctx.push_local(pp_name, type, bi); m_locals.push_back(r); @@ -488,6 +490,11 @@ public: return push_local(binding_name(e), binding_domain(e), binding_info(e)); } + expr push_local_from_let(expr const & e) { + lean_assert(is_let(e)); + return push_let(let_name(e), let_type(e), let_value(e)); + } + unsigned size() const { return m_locals.size(); } expr const * data() const { return m_locals.data(); }