feat(library/equations_compiler): add pack_mutual
This step packs a collection of mutually recursive functions into a single one. We use `psum` to combine the different domains, and `psum.cases_on` to combine the codomains.
This commit is contained in:
parent
22d0dc197c
commit
789d4e148f
9 changed files with 286 additions and 6 deletions
|
|
@ -153,15 +153,21 @@ environment mutual_definition_cmd_core(parser & p, def_cmd_kind kind, decl_modif
|
|||
buffer<expr> fns, params;
|
||||
declaration_info_scope scope(p, kind, modifiers);
|
||||
expr val = parse_mutual_definition(p, lp_names, fns, params);
|
||||
|
||||
// skip elaboration of definitions during reparsing
|
||||
if (p.get_break_at_pos())
|
||||
return p.env();
|
||||
|
||||
bool recover_from_errors = true;
|
||||
elaborator elab(p.env(), p.get_options(), get_namespace(p.env()) + local_pp_name(fns[0]), metavar_context(), local_context(), recover_from_errors);
|
||||
buffer<expr> new_params;
|
||||
elaborate_params(elab, params, new_params);
|
||||
val = replace_locals_preserving_pos_info(val, params, new_params);
|
||||
|
||||
// TODO(Leo)
|
||||
for (auto p : new_params) { tout() << ">> " << p << " : " << mlocal_type(p) << "\n"; }
|
||||
val = elab.elaborate(val);
|
||||
|
||||
tout() << val << "\n";
|
||||
// TODO(Leo)
|
||||
|
||||
return p.env();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,4 +2,4 @@ add_library(equations_compiler OBJECT
|
|||
equations.cpp util.cpp pack_domain.cpp
|
||||
structural_rec.cpp unbounded_rec.cpp
|
||||
elim_match.cpp compiler.cpp wf_rec.cpp
|
||||
init_module.cpp)
|
||||
pack_mutual.cpp init_module.cpp)
|
||||
|
|
|
|||
|
|
@ -9,12 +9,14 @@ Author: Leonardo de Moura
|
|||
#include "library/equations_compiler/structural_rec.h"
|
||||
#include "library/equations_compiler/wf_rec.h"
|
||||
#include "library/equations_compiler/elim_match.h"
|
||||
#include "library/equations_compiler/pack_mutual.h"
|
||||
#include "library/equations_compiler/compiler.h"
|
||||
|
||||
namespace lean{
|
||||
void initialize_equations_compiler_module() {
|
||||
initialize_eqn_compiler_util();
|
||||
initialize_equations();
|
||||
initialize_pack_mutual();
|
||||
initialize_structural_rec();
|
||||
initialize_wf_rec();
|
||||
initialize_elim_match();
|
||||
|
|
@ -26,6 +28,7 @@ void finalize_equations_compiler_module() {
|
|||
finalize_elim_match();
|
||||
finalize_structural_rec();
|
||||
finalize_wf_rec();
|
||||
finalize_pack_mutual();
|
||||
finalize_equations();
|
||||
finalize_eqn_compiler_util();
|
||||
}
|
||||
|
|
|
|||
214
src/library/equations_compiler/pack_mutual.cpp
Normal file
214
src/library/equations_compiler/pack_mutual.cpp
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
/*
|
||||
Copyright (c) 2017 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#include "kernel/instantiate.h"
|
||||
#include "library/constants.h"
|
||||
#include "library/trace.h"
|
||||
#include "library/app_builder.h"
|
||||
#include "library/type_context.h"
|
||||
#include "library/locals.h"
|
||||
#include "library/replace_visitor_with_tc.h"
|
||||
#include "library/equations_compiler/equations.h"
|
||||
#include "library/equations_compiler/util.h"
|
||||
|
||||
namespace lean {
|
||||
#define trace_debug_mutual(Code) lean_trace(name({"debug", "eqn_compiler", "mutual"}), scope_trace_env _scope(m_ctx.env(), m_ctx); Code)
|
||||
|
||||
struct pack_mutual_fn {
|
||||
type_context & m_ctx;
|
||||
|
||||
pack_mutual_fn(type_context & ctx):m_ctx(ctx) {}
|
||||
|
||||
expr mk_new_domain(buffer<expr> const & domains) {
|
||||
unsigned i = domains.size();
|
||||
lean_assert(i > 1);
|
||||
--i;
|
||||
expr r = domains[i];
|
||||
while (i > 0) {
|
||||
--i;
|
||||
r = mk_app(m_ctx, get_psum_name(), domains[i], r);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
expr mk_new_codomain(expr const & x, unsigned i, buffer<expr> const & codomains, level codomains_lvl) {
|
||||
if (i == codomains.size() - 1) {
|
||||
return instantiate(codomains[i], x);
|
||||
} else {
|
||||
expr x_type = m_ctx.relaxed_whnf(m_ctx.infer(x));
|
||||
buffer<expr> args;
|
||||
expr psum = get_app_args(x_type, args);
|
||||
lean_assert(const_name(psum) == get_psum_name());
|
||||
lean_assert(args.size() == 2);
|
||||
levels psum_cases_on_lvls(mk_succ(codomains_lvl), const_levels(psum));
|
||||
expr cases_on = mk_constant(get_psum_cases_on_name(), psum_cases_on_lvls);
|
||||
/* Add parameters */
|
||||
cases_on = mk_app(cases_on, args);
|
||||
/* Add motive */
|
||||
expr motive = mk_sort(codomains_lvl);
|
||||
cases_on = mk_app(cases_on, m_ctx.mk_lambda(x, motive));
|
||||
/* Add major */
|
||||
cases_on = mk_app(cases_on, x);
|
||||
/* Add minors */
|
||||
type_context::tmp_locals locals(m_ctx);
|
||||
expr y_1 = locals.push_local("_s", args[0]);
|
||||
expr m_1 = m_ctx.mk_lambda(y_1, instantiate(codomains[i], y_1));
|
||||
expr y_2 = locals.push_local("_s", args[1]);
|
||||
expr m_2 = mk_new_codomain(y_2, i+1, codomains, codomains_lvl);
|
||||
m_2 = m_ctx.mk_lambda(y_2, m_2);
|
||||
return mk_app(cases_on, m_1, m_2);
|
||||
}
|
||||
}
|
||||
|
||||
struct replace_fns : public replace_visitor_with_tc {
|
||||
unpack_eqns const & m_ues;
|
||||
expr m_new_fn;
|
||||
expr m_new_domain;
|
||||
|
||||
replace_fns(type_context & ctx, unpack_eqns const & ues, expr const & new_fn):
|
||||
replace_visitor_with_tc(ctx),
|
||||
m_ues(ues),
|
||||
m_new_fn(new_fn) {
|
||||
expr new_fn_type = m_ctx.relaxed_whnf(m_ctx.infer(m_new_fn));
|
||||
lean_assert(is_pi(new_fn_type));
|
||||
m_new_domain = m_ctx.relaxed_whnf(binding_domain(new_fn_type));
|
||||
lean_assert(is_constant(get_app_fn(m_new_domain)), get_psum_name());
|
||||
}
|
||||
|
||||
optional<unsigned> get_fidx(expr const & fn) const {
|
||||
if (!is_local(fn))
|
||||
return optional<unsigned>();
|
||||
for (unsigned fidx = 0; fidx < m_ues.get_num_fns(); fidx++) {
|
||||
if (mlocal_name(m_ues.get_fn(fidx)) == mlocal_name(fn))
|
||||
return optional<unsigned>(fidx);
|
||||
}
|
||||
return optional<unsigned>();
|
||||
}
|
||||
|
||||
expr mk_new_arg(expr const & e, unsigned fidx, unsigned i, expr psum_type) {
|
||||
if (i == m_ues.get_num_fns() - 1) {
|
||||
return e;
|
||||
} else {
|
||||
psum_type = m_ctx.relaxed_whnf(psum_type);
|
||||
buffer<expr> args;
|
||||
get_app_args(psum_type, args);
|
||||
lean_assert(args.size() == 2);
|
||||
if (i == fidx) {
|
||||
return mk_app(m_ctx, get_psum_inl_name(), args[0], args[1], e);
|
||||
} else {
|
||||
expr r = mk_new_arg(e, fidx, i+1, args[1]);
|
||||
return mk_app(m_ctx, get_psum_inr_name(), args[0], args[1], r);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
expr mk_new_arg(expr const & e, unsigned fidx) {
|
||||
return mk_new_arg(e, fidx, 0, m_new_domain);
|
||||
}
|
||||
|
||||
virtual expr visit_app(expr const & e) override {
|
||||
if (optional<unsigned> fidx = get_fidx(app_fn(e))) {
|
||||
expr arg = visit(app_arg(e));
|
||||
expr new_arg = mk_new_arg(arg, *fidx);
|
||||
return mk_app(m_new_fn, new_arg);
|
||||
} else {
|
||||
return replace_visitor_with_tc::visit_app(e);
|
||||
}
|
||||
}
|
||||
|
||||
virtual expr visit_local(expr const & e) override {
|
||||
if (get_fidx(e)) {
|
||||
throw generic_exception(e, "unexpected occurrence of recursive function\n");
|
||||
} else {
|
||||
return e;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
expr operator()(expr const & e) {
|
||||
unpack_eqns ues(m_ctx, e);
|
||||
if (ues.get_num_fns() == 1)
|
||||
return e;
|
||||
/* Given
|
||||
f_1 : Pi (x : A_1), B_1 x
|
||||
...
|
||||
f_n : Pi (x : A_n), B_n x
|
||||
|
||||
create a function with type
|
||||
f : Pi (x : psum A_1 ... (psum A_{n-1} A_n)), psum.cases_on x (fun y, B_1 y) (... (fun y, B_n y) ...)
|
||||
|
||||
remark: this module assumes the B_i's are in the same universe. */
|
||||
type_context::tmp_locals locals(m_ctx);
|
||||
buffer<expr> domains;
|
||||
buffer<expr> codomains;
|
||||
level codomains_lvl;
|
||||
name new_fn_name;
|
||||
for (unsigned fidx = 0; fidx < ues.get_num_fns(); fidx++) {
|
||||
expr const & fn = ues.get_fn(fidx);
|
||||
new_fn_name = new_fn_name + local_pp_name(fn);
|
||||
lean_assert(ues.get_arity_of(fidx) == 1);
|
||||
expr fn_type = m_ctx.relaxed_whnf(m_ctx.infer(fn));
|
||||
lean_assert(is_pi(fn_type));
|
||||
domains.push_back(binding_domain(fn_type));
|
||||
expr y = locals.push_local("_s", binding_domain(fn_type));
|
||||
expr c = instantiate(binding_body(fn_type), y);
|
||||
level c_lvl = get_level(m_ctx, c);
|
||||
if (fidx == 0) {
|
||||
codomains_lvl = c_lvl;
|
||||
} else if (!m_ctx.is_def_eq(c_lvl, codomains_lvl)) {
|
||||
throw generic_exception(e, "invalid mutual definition, result types must be in the same universe");
|
||||
}
|
||||
codomains.push_back(binding_body(fn_type));
|
||||
}
|
||||
|
||||
expr new_domain = mk_new_domain(domains);
|
||||
expr x = locals.push_local("_x", new_domain);
|
||||
expr new_codomain = mk_new_codomain(x, 0, codomains, codomains_lvl);
|
||||
expr new_fn_type = m_ctx.mk_pi(x, new_codomain);
|
||||
expr new_fn = locals.push_local(new_fn_name, new_fn_type);
|
||||
|
||||
trace_debug_mutual(tout() << "new function " << new_fn_name << " : " << new_fn_type << "\n";);
|
||||
|
||||
equations_header new_header = get_equations_header(e);
|
||||
new_header.m_num_fns = 1;
|
||||
|
||||
replace_fns replacer(m_ctx, ues, new_fn);
|
||||
|
||||
buffer<expr> new_eqns;
|
||||
for (unsigned fidx = 0; fidx < ues.get_num_fns(); fidx++) {
|
||||
buffer<expr> const & eqns = ues.get_eqns_of(fidx);
|
||||
for (expr const & eqn : eqns) {
|
||||
unpack_eqn ue(m_ctx, eqn);
|
||||
expr new_lhs = replacer(ue.lhs());
|
||||
expr new_rhs = replacer(ue.rhs());
|
||||
expr new_eqn = mk_equation(new_lhs, new_rhs, ue.ignore_if_unused());
|
||||
new_eqns.push_back(m_ctx.mk_lambda(new_fn, m_ctx.mk_lambda(ue.get_vars(), new_eqn)));
|
||||
}
|
||||
}
|
||||
|
||||
expr result;
|
||||
if (is_wf_equations(e))
|
||||
result = mk_equations(new_header, new_eqns.size(), new_eqns.data(), equations_wf_rel(e), equations_wf_proof(e));
|
||||
else
|
||||
result = mk_equations(new_header, new_eqns.size(), new_eqns.data());
|
||||
|
||||
trace_debug_mutual(tout() << "result\n" << result << "\n";);
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
expr pack_mutual(type_context & ctx, expr const & e) {
|
||||
return pack_mutual_fn(ctx)(e);
|
||||
}
|
||||
|
||||
void initialize_pack_mutual() {
|
||||
register_trace_class({"debug", "eqn_compiler", "mutual"});
|
||||
}
|
||||
|
||||
void finalize_pack_mutual() {
|
||||
}
|
||||
}
|
||||
16
src/library/equations_compiler/pack_mutual.h
Normal file
16
src/library/equations_compiler/pack_mutual.h
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
/*
|
||||
Copyright (c) 2017 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#pragma once
|
||||
#include "library/type_context.h"
|
||||
namespace lean {
|
||||
/** \brief Create a new equations object containing a single function.
|
||||
The functions must be unary. */
|
||||
expr pack_mutual(type_context & ctx, expr const & eqns);
|
||||
|
||||
void initialize_pack_mutual();
|
||||
void finalize_pack_mutual();
|
||||
}
|
||||
|
|
@ -165,6 +165,7 @@ unpack_eqn::unpack_eqn(type_context & ctx, expr const & eqn):
|
|||
m_nested_src = it;
|
||||
m_lhs = equation_lhs(it);
|
||||
m_rhs = equation_rhs(it);
|
||||
m_ignore_if_unused = ignore_equation_if_unused(it);
|
||||
}
|
||||
|
||||
expr unpack_eqn::add_var(name const & n, expr const & type) {
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ class unpack_eqn {
|
|||
expr m_nested_src;
|
||||
expr m_lhs;
|
||||
expr m_rhs;
|
||||
bool m_ignore_if_unused;
|
||||
public:
|
||||
unpack_eqn(type_context & ctx, expr const & eqn);
|
||||
expr add_var(name const & n, expr const & type);
|
||||
|
|
@ -66,6 +67,7 @@ public:
|
|||
expr & lhs() { return m_lhs; }
|
||||
expr & rhs() { return m_rhs; }
|
||||
expr const & get_nested_src() const { return m_nested_src; }
|
||||
bool ignore_if_unused() const { return m_ignore_if_unused; }
|
||||
expr repack();
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ Author: Leonardo de Moura
|
|||
#include "library/sorry.h" // remove after we add tactic for proving recursive calls are decreasing
|
||||
#include "library/replace_visitor_with_tc.h"
|
||||
#include "library/equations_compiler/pack_domain.h"
|
||||
#include "library/equations_compiler/pack_mutual.h"
|
||||
#include "library/equations_compiler/elim_match.h"
|
||||
#include "library/equations_compiler/util.h"
|
||||
|
||||
|
|
@ -48,7 +49,18 @@ struct wf_rec_fn {
|
|||
|
||||
expr pack_domain(expr const & eqns) {
|
||||
type_context ctx = mk_type_context();
|
||||
return ::lean::pack_domain(ctx, eqns);
|
||||
expr r = ::lean::pack_domain(ctx, eqns);
|
||||
m_env = ctx.env();
|
||||
m_mctx = ctx.mctx();
|
||||
return r;
|
||||
}
|
||||
|
||||
expr pack_mutual(expr const & eqns) {
|
||||
type_context ctx = mk_type_context();
|
||||
expr r = ::lean::pack_mutual(ctx, eqns);
|
||||
m_env = ctx.env();
|
||||
m_mctx = ctx.mctx();
|
||||
return r;
|
||||
}
|
||||
|
||||
expr_pair mk_wf_relation(expr const & eqns) {
|
||||
|
|
@ -196,8 +208,7 @@ struct wf_rec_fn {
|
|||
/* Make sure we have only one function */
|
||||
equations_header const & header = get_equations_header(eqns);
|
||||
if (header.m_num_fns > 1) {
|
||||
// TODO(Leo): combine functions
|
||||
throw exception("support for mutual recursion has not been implemented yet");
|
||||
eqns = pack_mutual(eqns);
|
||||
}
|
||||
|
||||
/* Retrieve well founded relation */
|
||||
|
|
|
|||
27
tmp/even_odd.lean
Normal file
27
tmp/even_odd.lean
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
import data.vector
|
||||
open nat
|
||||
universes u v
|
||||
set_option pp.all true
|
||||
|
||||
set_option trace.eqn_compiler.wf_rec true
|
||||
set_option trace.debug.eqn_compiler.wf_rec true
|
||||
set_option trace.debug.eqn_compiler.mutual true
|
||||
set_option trace.eqn_compiler.elim_match true
|
||||
|
||||
|
||||
mutual def even, odd
|
||||
with even : nat → bool
|
||||
| 0 := tt
|
||||
| (a+1) := odd a
|
||||
with odd : nat → bool
|
||||
| 0 := ff
|
||||
| (a+1) := even a
|
||||
|
||||
|
||||
mutual def f, g {α : Type u} {β : Type v} (f : α → β) (p : α × β)
|
||||
with f : Π n : nat, vector (α × β) n
|
||||
| 0 := vector.nil
|
||||
| (succ n) := vector.cons p $ (g n p.1).map (λ b, (p.1, b))
|
||||
with g : Π n : nat, α → vector β n
|
||||
| 0 a := vector.nil
|
||||
| (succ n) a := vector.cons p.2 $ (f n).map (λ p, p.2)
|
||||
Loading…
Add table
Reference in a new issue