feat(library/compiler/lambda_lifting): preserve join points when performing lambda lifting
This commit is contained in:
parent
adff01bba4
commit
b23251fd6e
1 changed files with 50 additions and 26 deletions
|
|
@ -9,6 +9,7 @@ Author: Leonardo de Moura
|
|||
#include "kernel/instantiate.h"
|
||||
#include "kernel/abstract.h"
|
||||
#include "kernel/for_each_fn.h"
|
||||
#include "library/trace.h"
|
||||
#include "library/compiler/util.h"
|
||||
|
||||
namespace lean {
|
||||
|
|
@ -19,9 +20,6 @@ class lambda_lifting_fn {
|
|||
buffer<comp_decl> m_new_decls;
|
||||
name m_base_name;
|
||||
unsigned m_next_idx{1};
|
||||
name m_y;
|
||||
unsigned m_next_let_idx{1};
|
||||
|
||||
|
||||
typedef std::unordered_set<name, name_hash> name_set;
|
||||
|
||||
|
|
@ -42,12 +40,6 @@ class lambda_lifting_fn {
|
|||
return m_lctx.mk_lambda(fvars, r);
|
||||
}
|
||||
|
||||
name next_let_name() {
|
||||
name r = m_y.append_after(m_next_let_idx);
|
||||
m_next_let_idx++;
|
||||
return r;
|
||||
}
|
||||
|
||||
expr visit_let(expr e) {
|
||||
flet<local_ctx> save_lctx(m_lctx, m_lctx);
|
||||
buffer<expr> fvars;
|
||||
|
|
@ -56,7 +48,7 @@ class lambda_lifting_fn {
|
|||
bool not_root = false;
|
||||
bool jp = is_join_point_name(let_name(e));
|
||||
expr new_val = visit(instantiate_rev(let_value(e), fvars.size(), fvars.data()), not_root, jp);
|
||||
expr new_fvar = m_lctx.mk_local_decl(ngen(), next_let_name(), let_type(e), new_val);
|
||||
expr new_fvar = m_lctx.mk_local_decl(ngen(), let_name(e), let_type(e), new_val);
|
||||
fvars.push_back(new_fvar);
|
||||
e = let_body(e);
|
||||
}
|
||||
|
|
@ -84,21 +76,34 @@ class lambda_lifting_fn {
|
|||
}
|
||||
}
|
||||
|
||||
void collect_fvars(expr const & e, buffer<expr> & fvars) {
|
||||
void collect_fvars_core(expr const & e, name_set collected, buffer<expr> & fvars, buffer<expr> & jps) {
|
||||
if (!has_fvar(e)) return;
|
||||
name_set collected;
|
||||
for_each(e, [&](expr const & x, unsigned) {
|
||||
if (!has_fvar(x)) return false;
|
||||
if (is_fvar(x)) {
|
||||
if (collected.find(fvar_name(x)) == collected.end()) {
|
||||
collected.insert(fvar_name(x));
|
||||
fvars.push_back(x);
|
||||
local_decl d = m_lctx.get_local_decl(x);
|
||||
/* We MUST copy any join point that lambda expression depends on, and
|
||||
its dependencies. */
|
||||
if (is_join_point_name(d.get_user_name())) {
|
||||
collect_fvars_core(*d.get_value(), collected, fvars, jps);
|
||||
jps.push_back(x);
|
||||
} else {
|
||||
fvars.push_back(x);
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
void collect_fvars(expr const & e, buffer<expr> & fvars, buffer<expr> & jps) {
|
||||
if (!has_fvar(e)) return;
|
||||
name_set collected;
|
||||
collect_fvars_core(e, collected, fvars, jps);
|
||||
}
|
||||
|
||||
/* Try to apply eta-reduction to reduce number of auxiliary declarations. */
|
||||
optional<expr> try_eta_reduction(expr const & e) {
|
||||
expr r = ::lean::try_eta(e);
|
||||
|
|
@ -115,19 +120,38 @@ class lambda_lifting_fn {
|
|||
return none_expr();
|
||||
}
|
||||
|
||||
name next_name(bool join_point) {
|
||||
name r(m_base_name, join_point ? "_join_point" : "_lambda");
|
||||
name next_name() {
|
||||
name r(m_base_name, "_lambda");
|
||||
r = r.append_after(m_next_idx);
|
||||
m_next_idx++;
|
||||
return r;
|
||||
}
|
||||
|
||||
/* Creates `fun <fvar>, e`. Remark: it is different from `m_lctx.mk_lambda` because
|
||||
it will create a lambda expression even if for free variables in `fvars` that correspond
|
||||
to let declarations. */
|
||||
expr mk_lambda(buffer<expr> const & fvars, expr e) {
|
||||
/* Given `e` of the form `fun xs, t`, create `fun fvars xs, let jps in e`. */
|
||||
expr mk_lambda(buffer<expr> const & fvars, buffer<expr> const & jps, expr e) {
|
||||
flet<local_ctx> save_lctx(m_lctx, m_lctx);
|
||||
buffer<expr> xs;
|
||||
while (is_lambda(e)) {
|
||||
lean_assert(!has_loose_bvars(binding_domain(e)));
|
||||
expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(e), binding_domain(e));
|
||||
xs.push_back(new_fvar);
|
||||
e = binding_body(e);
|
||||
}
|
||||
e = instantiate_rev(e, xs.size(), xs.data());
|
||||
e = abstract(e, jps.size(), jps.data());
|
||||
unsigned i = jps.size();
|
||||
while (i > 0) {
|
||||
--i;
|
||||
expr const & fvar = jps[i];
|
||||
local_decl decl = m_lctx.get_local_decl(fvar);
|
||||
lean_assert(is_join_point_name(decl.get_user_name()));
|
||||
lean_assert(!has_loose_bvars(decl.get_type()));
|
||||
expr val = abstract(*decl.get_value(), i, jps.data());
|
||||
e = ::lean::mk_let(decl.get_user_name(), decl.get_type(), val, e);
|
||||
}
|
||||
e = m_lctx.mk_lambda(xs, e);
|
||||
e = abstract(e, fvars.size(), fvars.data());
|
||||
unsigned i = fvars.size();
|
||||
i = fvars.size();
|
||||
while (i > 0) {
|
||||
--i;
|
||||
expr const & fvar = fvars[i];
|
||||
|
|
@ -140,14 +164,14 @@ class lambda_lifting_fn {
|
|||
|
||||
expr visit_lambda(expr e, bool root, bool join_point) {
|
||||
e = visit_lambda_core(e);
|
||||
if (root)
|
||||
if (root || join_point)
|
||||
return e;
|
||||
if (optional<expr> r = try_eta_reduction(e))
|
||||
return *r;
|
||||
buffer<expr> fvars;
|
||||
collect_fvars(e, fvars);
|
||||
e = mk_lambda(fvars, e);
|
||||
name new_fn = next_name(join_point);
|
||||
buffer<expr> fvars; buffer<expr> jps;
|
||||
collect_fvars(e, fvars, jps);
|
||||
e = mk_lambda(fvars, jps, e);
|
||||
name new_fn = next_name();
|
||||
m_new_decls.push_back(comp_decl(new_fn, e));
|
||||
return mk_app(mk_constant(new_fn), fvars);
|
||||
}
|
||||
|
|
@ -163,7 +187,7 @@ class lambda_lifting_fn {
|
|||
|
||||
public:
|
||||
lambda_lifting_fn(environment const & env):
|
||||
m_env(env), m_y("_y") {}
|
||||
m_env(env) {}
|
||||
|
||||
comp_decls operator()(comp_decl const & cdecl) {
|
||||
m_base_name = cdecl.fst();
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue