feat(library/equations_compiler): add pack_mutual

This step packs a collection of mutually recursive functions into a
single one. We use `psum` to combine the different domains, and
`psum.cases_on` to combine the codomains.
This commit is contained in:
Leonardo de Moura 2017-05-18 15:29:51 -07:00
parent 22d0dc197c
commit 789d4e148f
9 changed files with 286 additions and 6 deletions

View file

@ -153,15 +153,21 @@ environment mutual_definition_cmd_core(parser & p, def_cmd_kind kind, decl_modif
buffer<expr> fns, params;
declaration_info_scope scope(p, kind, modifiers);
expr val = parse_mutual_definition(p, lp_names, fns, params);
// skip elaboration of definitions during reparsing
if (p.get_break_at_pos())
return p.env();
bool recover_from_errors = true;
elaborator elab(p.env(), p.get_options(), get_namespace(p.env()) + local_pp_name(fns[0]), metavar_context(), local_context(), recover_from_errors);
buffer<expr> new_params;
elaborate_params(elab, params, new_params);
val = replace_locals_preserving_pos_info(val, params, new_params);
// TODO(Leo)
for (auto p : new_params) { tout() << ">> " << p << " : " << mlocal_type(p) << "\n"; }
val = elab.elaborate(val);
tout() << val << "\n";
// TODO(Leo)
return p.env();
}

View file

@ -2,4 +2,4 @@ add_library(equations_compiler OBJECT
equations.cpp util.cpp pack_domain.cpp
structural_rec.cpp unbounded_rec.cpp
elim_match.cpp compiler.cpp wf_rec.cpp
init_module.cpp)
pack_mutual.cpp init_module.cpp)

View file

@ -9,12 +9,14 @@ Author: Leonardo de Moura
#include "library/equations_compiler/structural_rec.h"
#include "library/equations_compiler/wf_rec.h"
#include "library/equations_compiler/elim_match.h"
#include "library/equations_compiler/pack_mutual.h"
#include "library/equations_compiler/compiler.h"
namespace lean{
void initialize_equations_compiler_module() {
initialize_eqn_compiler_util();
initialize_equations();
initialize_pack_mutual();
initialize_structural_rec();
initialize_wf_rec();
initialize_elim_match();
@ -26,6 +28,7 @@ void finalize_equations_compiler_module() {
finalize_elim_match();
finalize_structural_rec();
finalize_wf_rec();
finalize_pack_mutual();
finalize_equations();
finalize_eqn_compiler_util();
}

View file

@ -0,0 +1,214 @@
/*
Copyright (c) 2017 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include "kernel/instantiate.h"
#include "library/constants.h"
#include "library/trace.h"
#include "library/app_builder.h"
#include "library/type_context.h"
#include "library/locals.h"
#include "library/replace_visitor_with_tc.h"
#include "library/equations_compiler/equations.h"
#include "library/equations_compiler/util.h"
namespace lean {
#define trace_debug_mutual(Code) lean_trace(name({"debug", "eqn_compiler", "mutual"}), scope_trace_env _scope(m_ctx.env(), m_ctx); Code)
struct pack_mutual_fn {
type_context & m_ctx;
pack_mutual_fn(type_context & ctx):m_ctx(ctx) {}
expr mk_new_domain(buffer<expr> const & domains) {
unsigned i = domains.size();
lean_assert(i > 1);
--i;
expr r = domains[i];
while (i > 0) {
--i;
r = mk_app(m_ctx, get_psum_name(), domains[i], r);
}
return r;
}
expr mk_new_codomain(expr const & x, unsigned i, buffer<expr> const & codomains, level codomains_lvl) {
if (i == codomains.size() - 1) {
return instantiate(codomains[i], x);
} else {
expr x_type = m_ctx.relaxed_whnf(m_ctx.infer(x));
buffer<expr> args;
expr psum = get_app_args(x_type, args);
lean_assert(const_name(psum) == get_psum_name());
lean_assert(args.size() == 2);
levels psum_cases_on_lvls(mk_succ(codomains_lvl), const_levels(psum));
expr cases_on = mk_constant(get_psum_cases_on_name(), psum_cases_on_lvls);
/* Add parameters */
cases_on = mk_app(cases_on, args);
/* Add motive */
expr motive = mk_sort(codomains_lvl);
cases_on = mk_app(cases_on, m_ctx.mk_lambda(x, motive));
/* Add major */
cases_on = mk_app(cases_on, x);
/* Add minors */
type_context::tmp_locals locals(m_ctx);
expr y_1 = locals.push_local("_s", args[0]);
expr m_1 = m_ctx.mk_lambda(y_1, instantiate(codomains[i], y_1));
expr y_2 = locals.push_local("_s", args[1]);
expr m_2 = mk_new_codomain(y_2, i+1, codomains, codomains_lvl);
m_2 = m_ctx.mk_lambda(y_2, m_2);
return mk_app(cases_on, m_1, m_2);
}
}
struct replace_fns : public replace_visitor_with_tc {
unpack_eqns const & m_ues;
expr m_new_fn;
expr m_new_domain;
replace_fns(type_context & ctx, unpack_eqns const & ues, expr const & new_fn):
replace_visitor_with_tc(ctx),
m_ues(ues),
m_new_fn(new_fn) {
expr new_fn_type = m_ctx.relaxed_whnf(m_ctx.infer(m_new_fn));
lean_assert(is_pi(new_fn_type));
m_new_domain = m_ctx.relaxed_whnf(binding_domain(new_fn_type));
lean_assert(is_constant(get_app_fn(m_new_domain)), get_psum_name());
}
optional<unsigned> get_fidx(expr const & fn) const {
if (!is_local(fn))
return optional<unsigned>();
for (unsigned fidx = 0; fidx < m_ues.get_num_fns(); fidx++) {
if (mlocal_name(m_ues.get_fn(fidx)) == mlocal_name(fn))
return optional<unsigned>(fidx);
}
return optional<unsigned>();
}
expr mk_new_arg(expr const & e, unsigned fidx, unsigned i, expr psum_type) {
if (i == m_ues.get_num_fns() - 1) {
return e;
} else {
psum_type = m_ctx.relaxed_whnf(psum_type);
buffer<expr> args;
get_app_args(psum_type, args);
lean_assert(args.size() == 2);
if (i == fidx) {
return mk_app(m_ctx, get_psum_inl_name(), args[0], args[1], e);
} else {
expr r = mk_new_arg(e, fidx, i+1, args[1]);
return mk_app(m_ctx, get_psum_inr_name(), args[0], args[1], r);
}
}
}
expr mk_new_arg(expr const & e, unsigned fidx) {
return mk_new_arg(e, fidx, 0, m_new_domain);
}
virtual expr visit_app(expr const & e) override {
if (optional<unsigned> fidx = get_fidx(app_fn(e))) {
expr arg = visit(app_arg(e));
expr new_arg = mk_new_arg(arg, *fidx);
return mk_app(m_new_fn, new_arg);
} else {
return replace_visitor_with_tc::visit_app(e);
}
}
virtual expr visit_local(expr const & e) override {
if (get_fidx(e)) {
throw generic_exception(e, "unexpected occurrence of recursive function\n");
} else {
return e;
}
}
};
expr operator()(expr const & e) {
unpack_eqns ues(m_ctx, e);
if (ues.get_num_fns() == 1)
return e;
/* Given
f_1 : Pi (x : A_1), B_1 x
...
f_n : Pi (x : A_n), B_n x
create a function with type
f : Pi (x : psum A_1 ... (psum A_{n-1} A_n)), psum.cases_on x (fun y, B_1 y) (... (fun y, B_n y) ...)
remark: this module assumes the B_i's are in the same universe. */
type_context::tmp_locals locals(m_ctx);
buffer<expr> domains;
buffer<expr> codomains;
level codomains_lvl;
name new_fn_name;
for (unsigned fidx = 0; fidx < ues.get_num_fns(); fidx++) {
expr const & fn = ues.get_fn(fidx);
new_fn_name = new_fn_name + local_pp_name(fn);
lean_assert(ues.get_arity_of(fidx) == 1);
expr fn_type = m_ctx.relaxed_whnf(m_ctx.infer(fn));
lean_assert(is_pi(fn_type));
domains.push_back(binding_domain(fn_type));
expr y = locals.push_local("_s", binding_domain(fn_type));
expr c = instantiate(binding_body(fn_type), y);
level c_lvl = get_level(m_ctx, c);
if (fidx == 0) {
codomains_lvl = c_lvl;
} else if (!m_ctx.is_def_eq(c_lvl, codomains_lvl)) {
throw generic_exception(e, "invalid mutual definition, result types must be in the same universe");
}
codomains.push_back(binding_body(fn_type));
}
expr new_domain = mk_new_domain(domains);
expr x = locals.push_local("_x", new_domain);
expr new_codomain = mk_new_codomain(x, 0, codomains, codomains_lvl);
expr new_fn_type = m_ctx.mk_pi(x, new_codomain);
expr new_fn = locals.push_local(new_fn_name, new_fn_type);
trace_debug_mutual(tout() << "new function " << new_fn_name << " : " << new_fn_type << "\n";);
equations_header new_header = get_equations_header(e);
new_header.m_num_fns = 1;
replace_fns replacer(m_ctx, ues, new_fn);
buffer<expr> new_eqns;
for (unsigned fidx = 0; fidx < ues.get_num_fns(); fidx++) {
buffer<expr> const & eqns = ues.get_eqns_of(fidx);
for (expr const & eqn : eqns) {
unpack_eqn ue(m_ctx, eqn);
expr new_lhs = replacer(ue.lhs());
expr new_rhs = replacer(ue.rhs());
expr new_eqn = mk_equation(new_lhs, new_rhs, ue.ignore_if_unused());
new_eqns.push_back(m_ctx.mk_lambda(new_fn, m_ctx.mk_lambda(ue.get_vars(), new_eqn)));
}
}
expr result;
if (is_wf_equations(e))
result = mk_equations(new_header, new_eqns.size(), new_eqns.data(), equations_wf_rel(e), equations_wf_proof(e));
else
result = mk_equations(new_header, new_eqns.size(), new_eqns.data());
trace_debug_mutual(tout() << "result\n" << result << "\n";);
return result;
}
};
expr pack_mutual(type_context & ctx, expr const & e) {
return pack_mutual_fn(ctx)(e);
}
void initialize_pack_mutual() {
register_trace_class({"debug", "eqn_compiler", "mutual"});
}
void finalize_pack_mutual() {
}
}

View file

@ -0,0 +1,16 @@
/*
Copyright (c) 2017 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include "library/type_context.h"
namespace lean {
/** \brief Create a new equations object containing a single function.
The functions must be unary. */
expr pack_mutual(type_context & ctx, expr const & eqns);
void initialize_pack_mutual();
void finalize_pack_mutual();
}

View file

@ -165,6 +165,7 @@ unpack_eqn::unpack_eqn(type_context & ctx, expr const & eqn):
m_nested_src = it;
m_lhs = equation_lhs(it);
m_rhs = equation_rhs(it);
m_ignore_if_unused = ignore_equation_if_unused(it);
}
expr unpack_eqn::add_var(name const & n, expr const & type) {

View file

@ -59,6 +59,7 @@ class unpack_eqn {
expr m_nested_src;
expr m_lhs;
expr m_rhs;
bool m_ignore_if_unused;
public:
unpack_eqn(type_context & ctx, expr const & eqn);
expr add_var(name const & n, expr const & type);
@ -66,6 +67,7 @@ public:
expr & lhs() { return m_lhs; }
expr & rhs() { return m_rhs; }
expr const & get_nested_src() const { return m_nested_src; }
bool ignore_if_unused() const { return m_ignore_if_unused; }
expr repack();
};

View file

@ -13,6 +13,7 @@ Author: Leonardo de Moura
#include "library/sorry.h" // remove after we add tactic for proving recursive calls are decreasing
#include "library/replace_visitor_with_tc.h"
#include "library/equations_compiler/pack_domain.h"
#include "library/equations_compiler/pack_mutual.h"
#include "library/equations_compiler/elim_match.h"
#include "library/equations_compiler/util.h"
@ -48,7 +49,18 @@ struct wf_rec_fn {
expr pack_domain(expr const & eqns) {
type_context ctx = mk_type_context();
return ::lean::pack_domain(ctx, eqns);
expr r = ::lean::pack_domain(ctx, eqns);
m_env = ctx.env();
m_mctx = ctx.mctx();
return r;
}
expr pack_mutual(expr const & eqns) {
type_context ctx = mk_type_context();
expr r = ::lean::pack_mutual(ctx, eqns);
m_env = ctx.env();
m_mctx = ctx.mctx();
return r;
}
expr_pair mk_wf_relation(expr const & eqns) {
@ -196,8 +208,7 @@ struct wf_rec_fn {
/* Make sure we have only one function */
equations_header const & header = get_equations_header(eqns);
if (header.m_num_fns > 1) {
// TODO(Leo): combine functions
throw exception("support for mutual recursion has not been implemented yet");
eqns = pack_mutual(eqns);
}
/* Retrieve well founded relation */

27
tmp/even_odd.lean Normal file
View file

@ -0,0 +1,27 @@
import data.vector
open nat
universes u v
set_option pp.all true
set_option trace.eqn_compiler.wf_rec true
set_option trace.debug.eqn_compiler.wf_rec true
set_option trace.debug.eqn_compiler.mutual true
set_option trace.eqn_compiler.elim_match true
mutual def even, odd
with even : nat → bool
| 0 := tt
| (a+1) := odd a
with odd : nat → bool
| 0 := ff
| (a+1) := even a
mutual def f, g {α : Type u} {β : Type v} (f : α → β) (p : α × β)
with f : Π n : nat, vector (α × β) n
| 0 := vector.nil
| (succ n) := vector.cons p $ (g n p.1).map (λ b, (p.1, b))
with g : Π n : nat, α → vector β n
| 0 a := vector.nil
| (succ n) a := vector.cons p.2 $ (f n).map (λ p, p.2)