diff --git a/src/library/compiler/lambda_lifting.cpp b/src/library/compiler/lambda_lifting.cpp index 003cc89472..f1506cc484 100644 --- a/src/library/compiler/lambda_lifting.cpp +++ b/src/library/compiler/lambda_lifting.cpp @@ -9,6 +9,7 @@ Author: Leonardo de Moura #include "kernel/instantiate.h" #include "kernel/abstract.h" #include "kernel/for_each_fn.h" +#include "library/trace.h" #include "library/compiler/util.h" namespace lean { @@ -19,9 +20,6 @@ class lambda_lifting_fn { buffer m_new_decls; name m_base_name; unsigned m_next_idx{1}; - name m_y; - unsigned m_next_let_idx{1}; - typedef std::unordered_set name_set; @@ -42,12 +40,6 @@ class lambda_lifting_fn { return m_lctx.mk_lambda(fvars, r); } - name next_let_name() { - name r = m_y.append_after(m_next_let_idx); - m_next_let_idx++; - return r; - } - expr visit_let(expr e) { flet save_lctx(m_lctx, m_lctx); buffer fvars; @@ -56,7 +48,7 @@ class lambda_lifting_fn { bool not_root = false; bool jp = is_join_point_name(let_name(e)); expr new_val = visit(instantiate_rev(let_value(e), fvars.size(), fvars.data()), not_root, jp); - expr new_fvar = m_lctx.mk_local_decl(ngen(), next_let_name(), let_type(e), new_val); + expr new_fvar = m_lctx.mk_local_decl(ngen(), let_name(e), let_type(e), new_val); fvars.push_back(new_fvar); e = let_body(e); } @@ -84,21 +76,34 @@ class lambda_lifting_fn { } } - void collect_fvars(expr const & e, buffer & fvars) { + void collect_fvars_core(expr const & e, name_set collected, buffer & fvars, buffer & jps) { if (!has_fvar(e)) return; - name_set collected; for_each(e, [&](expr const & x, unsigned) { if (!has_fvar(x)) return false; if (is_fvar(x)) { if (collected.find(fvar_name(x)) == collected.end()) { collected.insert(fvar_name(x)); - fvars.push_back(x); + local_decl d = m_lctx.get_local_decl(x); + /* We MUST copy any join point that lambda expression depends on, and + its dependencies. */ + if (is_join_point_name(d.get_user_name())) { + collect_fvars_core(*d.get_value(), collected, fvars, jps); + jps.push_back(x); + } else { + fvars.push_back(x); + } } } return true; }); } + void collect_fvars(expr const & e, buffer & fvars, buffer & jps) { + if (!has_fvar(e)) return; + name_set collected; + collect_fvars_core(e, collected, fvars, jps); + } + /* Try to apply eta-reduction to reduce number of auxiliary declarations. */ optional try_eta_reduction(expr const & e) { expr r = ::lean::try_eta(e); @@ -115,19 +120,38 @@ class lambda_lifting_fn { return none_expr(); } - name next_name(bool join_point) { - name r(m_base_name, join_point ? "_join_point" : "_lambda"); + name next_name() { + name r(m_base_name, "_lambda"); r = r.append_after(m_next_idx); m_next_idx++; return r; } - /* Creates `fun , e`. Remark: it is different from `m_lctx.mk_lambda` because - it will create a lambda expression even if for free variables in `fvars` that correspond - to let declarations. */ - expr mk_lambda(buffer const & fvars, expr e) { + /* Given `e` of the form `fun xs, t`, create `fun fvars xs, let jps in e`. */ + expr mk_lambda(buffer const & fvars, buffer const & jps, expr e) { + flet save_lctx(m_lctx, m_lctx); + buffer xs; + while (is_lambda(e)) { + lean_assert(!has_loose_bvars(binding_domain(e))); + expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(e), binding_domain(e)); + xs.push_back(new_fvar); + e = binding_body(e); + } + e = instantiate_rev(e, xs.size(), xs.data()); + e = abstract(e, jps.size(), jps.data()); + unsigned i = jps.size(); + while (i > 0) { + --i; + expr const & fvar = jps[i]; + local_decl decl = m_lctx.get_local_decl(fvar); + lean_assert(is_join_point_name(decl.get_user_name())); + lean_assert(!has_loose_bvars(decl.get_type())); + expr val = abstract(*decl.get_value(), i, jps.data()); + e = ::lean::mk_let(decl.get_user_name(), decl.get_type(), val, e); + } + e = m_lctx.mk_lambda(xs, e); e = abstract(e, fvars.size(), fvars.data()); - unsigned i = fvars.size(); + i = fvars.size(); while (i > 0) { --i; expr const & fvar = fvars[i]; @@ -140,14 +164,14 @@ class lambda_lifting_fn { expr visit_lambda(expr e, bool root, bool join_point) { e = visit_lambda_core(e); - if (root) + if (root || join_point) return e; if (optional r = try_eta_reduction(e)) return *r; - buffer fvars; - collect_fvars(e, fvars); - e = mk_lambda(fvars, e); - name new_fn = next_name(join_point); + buffer fvars; buffer jps; + collect_fvars(e, fvars, jps); + e = mk_lambda(fvars, jps, e); + name new_fn = next_name(); m_new_decls.push_back(comp_decl(new_fn, e)); return mk_app(mk_constant(new_fn), fvars); } @@ -163,7 +187,7 @@ class lambda_lifting_fn { public: lambda_lifting_fn(environment const & env): - m_env(env), m_y("_y") {} + m_env(env) {} comp_decls operator()(comp_decl const & cdecl) { m_base_name = cdecl.fst();