diff --git a/src/library/compiler/eager_lambda_lifting.cpp b/src/library/compiler/eager_lambda_lifting.cpp index fb44b77c6f..84e8585c36 100644 --- a/src/library/compiler/eager_lambda_lifting.cpp +++ b/src/library/compiler/eager_lambda_lifting.cpp @@ -27,6 +27,33 @@ bool is_elambda_lifting_name(name fn) { return is_eager_lambda_lifting_name_core(fn.to_obj_arg()); } +/* Return true iff `e` contains a free variable that is not in `exception_set`. */ +static bool has_fvar_except(expr const & e, name_set const & exception_set) { + if (!has_fvar(e)) return false; + bool found = false; + for_each(e, [&](expr const & e, unsigned) { + if (!has_fvar(e)) return false; + if (found) return false; // done + if (is_fvar(e) && !exception_set.contains(fvar_name(e))) { + found = true; + return false; // done + } + return true; + }); + return found; +} + +/* Return true if the type of a parameter in `params` depends on `fvar`. */ +static bool depends_on_fvar(local_ctx const & lctx, buffer const & params, expr const & fvar) { + for (expr const & param : params) { + local_decl const & decl = lctx.get_local_decl(param); + lean_assert(!decl.get_value()); + if (has_fvar(decl.get_type(), fvar)) + return true; + } + return false; +} + /* We eagerly lift lambda expressions that are stored in terminal constructors. We say a constructor application is terminal if it is the result/returned. @@ -74,6 +101,7 @@ class eager_lambda_lifting_fn { local_ctx m_lctx; buffer m_new_decls; name m_base_name; + name_set m_closed_fvars; /* let-declarations that only depend on global constants and other closed_fvars */ name_set m_terminal_lambdas; name_set m_nonterminal_lambdas; unsigned m_next_idx{1}; @@ -88,7 +116,7 @@ class eager_lambda_lifting_fn { return r; } - bool collect_fvars_core(expr const & e, name_set collected, buffer & fvars) { + bool collect_fvars_core(expr const & e, name_set & collected, buffer & fvars) { if (!has_fvar(e)) return true; bool ok = true; for_each(e, [&](expr const & x, unsigned) { @@ -108,6 +136,12 @@ class eager_lambda_lifting_fn { } else { if (!collect_fvars_core(d.get_type(), collected, fvars)) return false; + if (m_closed_fvars.contains(fvar_name(x))) { + /* If x only depends on global constants and other variables in m_closed_fvars. + Then, we also collect the other variables at m_closed_fvars. */ + if (!collect_fvars_core(*d.get_value(), collected, fvars)) + return false; + } fvars.push_back(x); } } @@ -128,18 +162,58 @@ class eager_lambda_lifting_fn { } } - expr lift_lambda(expr const & e) { + /* Split fvars in two groups: `new_params` and `to_copy`. + We put a fvar `x` in `new_params` if it is not a let declaration, + or a variable in `params` depend on `x`, or it is not in `m_closed_fvars`. + + The variables in `to_copy` are variables that depend only on + global constants or other variables in `to_copy`, and `params` do not depend on them. */ + void split_fvars(buffer const & fvars, buffer const & params, buffer & new_params, buffer & to_copy) { + for (expr const & fvar : fvars) { + local_decl const & decl = m_lctx.get_local_decl(fvar); + if (!decl.get_value()) { + new_params.push_back(fvar); + } else { + if (!m_closed_fvars.contains(fvar_name(fvar)) || depends_on_fvar(m_lctx, params, fvar)) { + new_params.push_back(fvar); + } else { + to_copy.push_back(fvar); + } + } + } + } + + expr lift_lambda(expr e) { lean_assert(is_lambda(e)); buffer fvars; if (!collect_fvars(e, fvars)) { return e; } - expr code = abstract(e, fvars.size(), fvars.data()); - unsigned i = fvars.size(); + buffer params; + while (is_lambda(e)) { + expr param_type = instantiate_rev(binding_domain(e), params.size(), params.data()); + expr param = m_lctx.mk_local_decl(ngen(), binding_name(e), param_type, binding_info(e)); + params.push_back(param); + e = binding_body(e); + } + e = instantiate_rev(e, params.size(), params.data()); + buffer new_params, to_copy; + split_fvars(fvars, params, new_params, to_copy); + /* + Variables in `to_copy` only depend on global constants + and other variables in `to_copy`. Moreover, `params` do not depend on them. + It is wasteful to pass them as new parameters to the new lifted declaration. + We can just copy them. The code duplication is not problematic because later at `extract_closed` + we will create global names for closed terms, and eliminate the redundancy. + */ + e = m_lctx.mk_lambda(to_copy, e); + e = m_lctx.mk_lambda(params, e); + expr code = abstract(e, new_params.size(), new_params.data()); + unsigned i = new_params.size(); while (i > 0) { --i; - local_decl const & decl = m_lctx.get_local_decl(fvars[i]); - expr type = abstract(decl.get_type(), i, fvars.data()); + local_decl const & decl = m_lctx.get_local_decl(new_params[i]); + expr type = abstract(decl.get_type(), i, new_params.data()); code = ::lean::mk_lambda(decl.get_user_name(), type, code); } expr type = cheap_beta_reduce(type_checker(m_st).infer(code)); @@ -151,7 +225,7 @@ class eager_lambda_lifting_fn { declaration aux_ax = mk_axiom(n, names(), type, true /* meta */); m_st.env() = env().add(aux_ax, false); m_new_decls.push_back(comp_decl(n, code)); - return mk_app(mk_constant(n), fvars); + return mk_app(mk_constant(n), new_params); } /* Given a free variable `x`, follow let-decls and return a pair `(x, v)`. @@ -208,6 +282,9 @@ class eager_lambda_lifting_fn { expr new_type = instantiate_rev(let_type(e), fvars.size(), fvars.data()); expr new_val = visit(instantiate_rev(let_value(e), fvars.size(), fvars.data()), not_root, jp); expr new_fvar = m_lctx.mk_local_decl(ngen(), let_name(e), new_type, new_val); + if (!has_fvar_except(new_type, m_closed_fvars) && !has_fvar_except(new_val, m_closed_fvars)) { + m_closed_fvars.insert(fvar_name(new_fvar)); + } fvars.push_back(new_fvar); e = let_body(e); }