diff --git a/src/frontends/lean/definition_cmds.cpp b/src/frontends/lean/definition_cmds.cpp index 1e18899a42..546b4b7584 100644 --- a/src/frontends/lean/definition_cmds.cpp +++ b/src/frontends/lean/definition_cmds.cpp @@ -153,15 +153,21 @@ environment mutual_definition_cmd_core(parser & p, def_cmd_kind kind, decl_modif buffer fns, params; declaration_info_scope scope(p, kind, modifiers); expr val = parse_mutual_definition(p, lp_names, fns, params); + + // skip elaboration of definitions during reparsing + if (p.get_break_at_pos()) + return p.env(); + bool recover_from_errors = true; elaborator elab(p.env(), p.get_options(), get_namespace(p.env()) + local_pp_name(fns[0]), metavar_context(), local_context(), recover_from_errors); buffer new_params; elaborate_params(elab, params, new_params); val = replace_locals_preserving_pos_info(val, params, new_params); - // TODO(Leo) - for (auto p : new_params) { tout() << ">> " << p << " : " << mlocal_type(p) << "\n"; } + val = elab.elaborate(val); + tout() << val << "\n"; + // TODO(Leo) return p.env(); } diff --git a/src/library/equations_compiler/CMakeLists.txt b/src/library/equations_compiler/CMakeLists.txt index 9abaf771dc..350be06f93 100644 --- a/src/library/equations_compiler/CMakeLists.txt +++ b/src/library/equations_compiler/CMakeLists.txt @@ -2,4 +2,4 @@ add_library(equations_compiler OBJECT equations.cpp util.cpp pack_domain.cpp structural_rec.cpp unbounded_rec.cpp elim_match.cpp compiler.cpp wf_rec.cpp - init_module.cpp) + pack_mutual.cpp init_module.cpp) diff --git a/src/library/equations_compiler/init_module.cpp b/src/library/equations_compiler/init_module.cpp index f4ae69f370..6a47cb832b 100644 --- a/src/library/equations_compiler/init_module.cpp +++ b/src/library/equations_compiler/init_module.cpp @@ -9,12 +9,14 @@ Author: Leonardo de Moura #include "library/equations_compiler/structural_rec.h" #include "library/equations_compiler/wf_rec.h" #include "library/equations_compiler/elim_match.h" +#include "library/equations_compiler/pack_mutual.h" #include "library/equations_compiler/compiler.h" namespace lean{ void initialize_equations_compiler_module() { initialize_eqn_compiler_util(); initialize_equations(); + initialize_pack_mutual(); initialize_structural_rec(); initialize_wf_rec(); initialize_elim_match(); @@ -26,6 +28,7 @@ void finalize_equations_compiler_module() { finalize_elim_match(); finalize_structural_rec(); finalize_wf_rec(); + finalize_pack_mutual(); finalize_equations(); finalize_eqn_compiler_util(); } diff --git a/src/library/equations_compiler/pack_mutual.cpp b/src/library/equations_compiler/pack_mutual.cpp new file mode 100644 index 0000000000..ff868f984e --- /dev/null +++ b/src/library/equations_compiler/pack_mutual.cpp @@ -0,0 +1,214 @@ +/* +Copyright (c) 2017 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include "kernel/instantiate.h" +#include "library/constants.h" +#include "library/trace.h" +#include "library/app_builder.h" +#include "library/type_context.h" +#include "library/locals.h" +#include "library/replace_visitor_with_tc.h" +#include "library/equations_compiler/equations.h" +#include "library/equations_compiler/util.h" + +namespace lean { +#define trace_debug_mutual(Code) lean_trace(name({"debug", "eqn_compiler", "mutual"}), scope_trace_env _scope(m_ctx.env(), m_ctx); Code) + +struct pack_mutual_fn { + type_context & m_ctx; + + pack_mutual_fn(type_context & ctx):m_ctx(ctx) {} + + expr mk_new_domain(buffer const & domains) { + unsigned i = domains.size(); + lean_assert(i > 1); + --i; + expr r = domains[i]; + while (i > 0) { + --i; + r = mk_app(m_ctx, get_psum_name(), domains[i], r); + } + return r; + } + + expr mk_new_codomain(expr const & x, unsigned i, buffer const & codomains, level codomains_lvl) { + if (i == codomains.size() - 1) { + return instantiate(codomains[i], x); + } else { + expr x_type = m_ctx.relaxed_whnf(m_ctx.infer(x)); + buffer args; + expr psum = get_app_args(x_type, args); + lean_assert(const_name(psum) == get_psum_name()); + lean_assert(args.size() == 2); + levels psum_cases_on_lvls(mk_succ(codomains_lvl), const_levels(psum)); + expr cases_on = mk_constant(get_psum_cases_on_name(), psum_cases_on_lvls); + /* Add parameters */ + cases_on = mk_app(cases_on, args); + /* Add motive */ + expr motive = mk_sort(codomains_lvl); + cases_on = mk_app(cases_on, m_ctx.mk_lambda(x, motive)); + /* Add major */ + cases_on = mk_app(cases_on, x); + /* Add minors */ + type_context::tmp_locals locals(m_ctx); + expr y_1 = locals.push_local("_s", args[0]); + expr m_1 = m_ctx.mk_lambda(y_1, instantiate(codomains[i], y_1)); + expr y_2 = locals.push_local("_s", args[1]); + expr m_2 = mk_new_codomain(y_2, i+1, codomains, codomains_lvl); + m_2 = m_ctx.mk_lambda(y_2, m_2); + return mk_app(cases_on, m_1, m_2); + } + } + + struct replace_fns : public replace_visitor_with_tc { + unpack_eqns const & m_ues; + expr m_new_fn; + expr m_new_domain; + + replace_fns(type_context & ctx, unpack_eqns const & ues, expr const & new_fn): + replace_visitor_with_tc(ctx), + m_ues(ues), + m_new_fn(new_fn) { + expr new_fn_type = m_ctx.relaxed_whnf(m_ctx.infer(m_new_fn)); + lean_assert(is_pi(new_fn_type)); + m_new_domain = m_ctx.relaxed_whnf(binding_domain(new_fn_type)); + lean_assert(is_constant(get_app_fn(m_new_domain)), get_psum_name()); + } + + optional get_fidx(expr const & fn) const { + if (!is_local(fn)) + return optional(); + for (unsigned fidx = 0; fidx < m_ues.get_num_fns(); fidx++) { + if (mlocal_name(m_ues.get_fn(fidx)) == mlocal_name(fn)) + return optional(fidx); + } + return optional(); + } + + expr mk_new_arg(expr const & e, unsigned fidx, unsigned i, expr psum_type) { + if (i == m_ues.get_num_fns() - 1) { + return e; + } else { + psum_type = m_ctx.relaxed_whnf(psum_type); + buffer args; + get_app_args(psum_type, args); + lean_assert(args.size() == 2); + if (i == fidx) { + return mk_app(m_ctx, get_psum_inl_name(), args[0], args[1], e); + } else { + expr r = mk_new_arg(e, fidx, i+1, args[1]); + return mk_app(m_ctx, get_psum_inr_name(), args[0], args[1], r); + } + } + } + + expr mk_new_arg(expr const & e, unsigned fidx) { + return mk_new_arg(e, fidx, 0, m_new_domain); + } + + virtual expr visit_app(expr const & e) override { + if (optional fidx = get_fidx(app_fn(e))) { + expr arg = visit(app_arg(e)); + expr new_arg = mk_new_arg(arg, *fidx); + return mk_app(m_new_fn, new_arg); + } else { + return replace_visitor_with_tc::visit_app(e); + } + } + + virtual expr visit_local(expr const & e) override { + if (get_fidx(e)) { + throw generic_exception(e, "unexpected occurrence of recursive function\n"); + } else { + return e; + } + } + }; + + expr operator()(expr const & e) { + unpack_eqns ues(m_ctx, e); + if (ues.get_num_fns() == 1) + return e; + /* Given + f_1 : Pi (x : A_1), B_1 x + ... + f_n : Pi (x : A_n), B_n x + + create a function with type + f : Pi (x : psum A_1 ... (psum A_{n-1} A_n)), psum.cases_on x (fun y, B_1 y) (... (fun y, B_n y) ...) + + remark: this module assumes the B_i's are in the same universe. */ + type_context::tmp_locals locals(m_ctx); + buffer domains; + buffer codomains; + level codomains_lvl; + name new_fn_name; + for (unsigned fidx = 0; fidx < ues.get_num_fns(); fidx++) { + expr const & fn = ues.get_fn(fidx); + new_fn_name = new_fn_name + local_pp_name(fn); + lean_assert(ues.get_arity_of(fidx) == 1); + expr fn_type = m_ctx.relaxed_whnf(m_ctx.infer(fn)); + lean_assert(is_pi(fn_type)); + domains.push_back(binding_domain(fn_type)); + expr y = locals.push_local("_s", binding_domain(fn_type)); + expr c = instantiate(binding_body(fn_type), y); + level c_lvl = get_level(m_ctx, c); + if (fidx == 0) { + codomains_lvl = c_lvl; + } else if (!m_ctx.is_def_eq(c_lvl, codomains_lvl)) { + throw generic_exception(e, "invalid mutual definition, result types must be in the same universe"); + } + codomains.push_back(binding_body(fn_type)); + } + + expr new_domain = mk_new_domain(domains); + expr x = locals.push_local("_x", new_domain); + expr new_codomain = mk_new_codomain(x, 0, codomains, codomains_lvl); + expr new_fn_type = m_ctx.mk_pi(x, new_codomain); + expr new_fn = locals.push_local(new_fn_name, new_fn_type); + + trace_debug_mutual(tout() << "new function " << new_fn_name << " : " << new_fn_type << "\n";); + + equations_header new_header = get_equations_header(e); + new_header.m_num_fns = 1; + + replace_fns replacer(m_ctx, ues, new_fn); + + buffer new_eqns; + for (unsigned fidx = 0; fidx < ues.get_num_fns(); fidx++) { + buffer const & eqns = ues.get_eqns_of(fidx); + for (expr const & eqn : eqns) { + unpack_eqn ue(m_ctx, eqn); + expr new_lhs = replacer(ue.lhs()); + expr new_rhs = replacer(ue.rhs()); + expr new_eqn = mk_equation(new_lhs, new_rhs, ue.ignore_if_unused()); + new_eqns.push_back(m_ctx.mk_lambda(new_fn, m_ctx.mk_lambda(ue.get_vars(), new_eqn))); + } + } + + expr result; + if (is_wf_equations(e)) + result = mk_equations(new_header, new_eqns.size(), new_eqns.data(), equations_wf_rel(e), equations_wf_proof(e)); + else + result = mk_equations(new_header, new_eqns.size(), new_eqns.data()); + + trace_debug_mutual(tout() << "result\n" << result << "\n";); + + return result; + } +}; + +expr pack_mutual(type_context & ctx, expr const & e) { + return pack_mutual_fn(ctx)(e); +} + +void initialize_pack_mutual() { + register_trace_class({"debug", "eqn_compiler", "mutual"}); +} + +void finalize_pack_mutual() { +} +} diff --git a/src/library/equations_compiler/pack_mutual.h b/src/library/equations_compiler/pack_mutual.h new file mode 100644 index 0000000000..75638df424 --- /dev/null +++ b/src/library/equations_compiler/pack_mutual.h @@ -0,0 +1,16 @@ +/* +Copyright (c) 2017 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include "library/type_context.h" +namespace lean { +/** \brief Create a new equations object containing a single function. + The functions must be unary. */ +expr pack_mutual(type_context & ctx, expr const & eqns); + +void initialize_pack_mutual(); +void finalize_pack_mutual(); +} diff --git a/src/library/equations_compiler/util.cpp b/src/library/equations_compiler/util.cpp index 82adaa18b7..85cf9b572d 100644 --- a/src/library/equations_compiler/util.cpp +++ b/src/library/equations_compiler/util.cpp @@ -165,6 +165,7 @@ unpack_eqn::unpack_eqn(type_context & ctx, expr const & eqn): m_nested_src = it; m_lhs = equation_lhs(it); m_rhs = equation_rhs(it); + m_ignore_if_unused = ignore_equation_if_unused(it); } expr unpack_eqn::add_var(name const & n, expr const & type) { diff --git a/src/library/equations_compiler/util.h b/src/library/equations_compiler/util.h index ecd058605c..a5056d2577 100644 --- a/src/library/equations_compiler/util.h +++ b/src/library/equations_compiler/util.h @@ -59,6 +59,7 @@ class unpack_eqn { expr m_nested_src; expr m_lhs; expr m_rhs; + bool m_ignore_if_unused; public: unpack_eqn(type_context & ctx, expr const & eqn); expr add_var(name const & n, expr const & type); @@ -66,6 +67,7 @@ public: expr & lhs() { return m_lhs; } expr & rhs() { return m_rhs; } expr const & get_nested_src() const { return m_nested_src; } + bool ignore_if_unused() const { return m_ignore_if_unused; } expr repack(); }; diff --git a/src/library/equations_compiler/wf_rec.cpp b/src/library/equations_compiler/wf_rec.cpp index f1449f0b84..f9911fc20e 100644 --- a/src/library/equations_compiler/wf_rec.cpp +++ b/src/library/equations_compiler/wf_rec.cpp @@ -13,6 +13,7 @@ Author: Leonardo de Moura #include "library/sorry.h" // remove after we add tactic for proving recursive calls are decreasing #include "library/replace_visitor_with_tc.h" #include "library/equations_compiler/pack_domain.h" +#include "library/equations_compiler/pack_mutual.h" #include "library/equations_compiler/elim_match.h" #include "library/equations_compiler/util.h" @@ -48,7 +49,18 @@ struct wf_rec_fn { expr pack_domain(expr const & eqns) { type_context ctx = mk_type_context(); - return ::lean::pack_domain(ctx, eqns); + expr r = ::lean::pack_domain(ctx, eqns); + m_env = ctx.env(); + m_mctx = ctx.mctx(); + return r; + } + + expr pack_mutual(expr const & eqns) { + type_context ctx = mk_type_context(); + expr r = ::lean::pack_mutual(ctx, eqns); + m_env = ctx.env(); + m_mctx = ctx.mctx(); + return r; } expr_pair mk_wf_relation(expr const & eqns) { @@ -196,8 +208,7 @@ struct wf_rec_fn { /* Make sure we have only one function */ equations_header const & header = get_equations_header(eqns); if (header.m_num_fns > 1) { - // TODO(Leo): combine functions - throw exception("support for mutual recursion has not been implemented yet"); + eqns = pack_mutual(eqns); } /* Retrieve well founded relation */ diff --git a/tmp/even_odd.lean b/tmp/even_odd.lean new file mode 100644 index 0000000000..13f2042fe5 --- /dev/null +++ b/tmp/even_odd.lean @@ -0,0 +1,27 @@ +import data.vector +open nat +universes u v +set_option pp.all true + +set_option trace.eqn_compiler.wf_rec true +set_option trace.debug.eqn_compiler.wf_rec true +set_option trace.debug.eqn_compiler.mutual true +set_option trace.eqn_compiler.elim_match true + + +mutual def even, odd +with even : nat → bool +| 0 := tt +| (a+1) := odd a +with odd : nat → bool +| 0 := ff +| (a+1) := even a + + +mutual def f, g {α : Type u} {β : Type v} (f : α → β) (p : α × β) +with f : Π n : nat, vector (α × β) n +| 0 := vector.nil +| (succ n) := vector.cons p $ (g n p.1).map (λ b, (p.1, b)) +with g : Π n : nat, α → vector β n +| 0 a := vector.nil +| (succ n) a := vector.cons p.2 $ (f n).map (λ p, p.2)