feat(library/compiler/eager_lambda_lifting): do not pass closed terms as arguments to lifted decl
This commit is contained in:
parent
0f873b7ba2
commit
f1a2c83e8c
1 changed files with 84 additions and 7 deletions
|
|
@ -27,6 +27,33 @@ bool is_elambda_lifting_name(name fn) {
|
|||
return is_eager_lambda_lifting_name_core(fn.to_obj_arg());
|
||||
}
|
||||
|
||||
/* Return true iff `e` contains a free variable that is not in `exception_set`. */
|
||||
static bool has_fvar_except(expr const & e, name_set const & exception_set) {
|
||||
if (!has_fvar(e)) return false;
|
||||
bool found = false;
|
||||
for_each(e, [&](expr const & e, unsigned) {
|
||||
if (!has_fvar(e)) return false;
|
||||
if (found) return false; // done
|
||||
if (is_fvar(e) && !exception_set.contains(fvar_name(e))) {
|
||||
found = true;
|
||||
return false; // done
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
|
||||
/* Return true if the type of a parameter in `params` depends on `fvar`. */
|
||||
static bool depends_on_fvar(local_ctx const & lctx, buffer<expr> const & params, expr const & fvar) {
|
||||
for (expr const & param : params) {
|
||||
local_decl const & decl = lctx.get_local_decl(param);
|
||||
lean_assert(!decl.get_value());
|
||||
if (has_fvar(decl.get_type(), fvar))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/*
|
||||
We eagerly lift lambda expressions that are stored in terminal constructors.
|
||||
We say a constructor application is terminal if it is the result/returned.
|
||||
|
|
@ -74,6 +101,7 @@ class eager_lambda_lifting_fn {
|
|||
local_ctx m_lctx;
|
||||
buffer<comp_decl> m_new_decls;
|
||||
name m_base_name;
|
||||
name_set m_closed_fvars; /* let-declarations that only depend on global constants and other closed_fvars */
|
||||
name_set m_terminal_lambdas;
|
||||
name_set m_nonterminal_lambdas;
|
||||
unsigned m_next_idx{1};
|
||||
|
|
@ -88,7 +116,7 @@ class eager_lambda_lifting_fn {
|
|||
return r;
|
||||
}
|
||||
|
||||
bool collect_fvars_core(expr const & e, name_set collected, buffer<expr> & fvars) {
|
||||
bool collect_fvars_core(expr const & e, name_set & collected, buffer<expr> & fvars) {
|
||||
if (!has_fvar(e)) return true;
|
||||
bool ok = true;
|
||||
for_each(e, [&](expr const & x, unsigned) {
|
||||
|
|
@ -108,6 +136,12 @@ class eager_lambda_lifting_fn {
|
|||
} else {
|
||||
if (!collect_fvars_core(d.get_type(), collected, fvars))
|
||||
return false;
|
||||
if (m_closed_fvars.contains(fvar_name(x))) {
|
||||
/* If x only depends on global constants and other variables in m_closed_fvars.
|
||||
Then, we also collect the other variables at m_closed_fvars. */
|
||||
if (!collect_fvars_core(*d.get_value(), collected, fvars))
|
||||
return false;
|
||||
}
|
||||
fvars.push_back(x);
|
||||
}
|
||||
}
|
||||
|
|
@ -128,18 +162,58 @@ class eager_lambda_lifting_fn {
|
|||
}
|
||||
}
|
||||
|
||||
expr lift_lambda(expr const & e) {
|
||||
/* Split fvars in two groups: `new_params` and `to_copy`.
|
||||
We put a fvar `x` in `new_params` if it is not a let declaration,
|
||||
or a variable in `params` depend on `x`, or it is not in `m_closed_fvars`.
|
||||
|
||||
The variables in `to_copy` are variables that depend only on
|
||||
global constants or other variables in `to_copy`, and `params` do not depend on them. */
|
||||
void split_fvars(buffer<expr> const & fvars, buffer<expr> const & params, buffer<expr> & new_params, buffer<expr> & to_copy) {
|
||||
for (expr const & fvar : fvars) {
|
||||
local_decl const & decl = m_lctx.get_local_decl(fvar);
|
||||
if (!decl.get_value()) {
|
||||
new_params.push_back(fvar);
|
||||
} else {
|
||||
if (!m_closed_fvars.contains(fvar_name(fvar)) || depends_on_fvar(m_lctx, params, fvar)) {
|
||||
new_params.push_back(fvar);
|
||||
} else {
|
||||
to_copy.push_back(fvar);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
expr lift_lambda(expr e) {
|
||||
lean_assert(is_lambda(e));
|
||||
buffer<expr> fvars;
|
||||
if (!collect_fvars(e, fvars)) {
|
||||
return e;
|
||||
}
|
||||
expr code = abstract(e, fvars.size(), fvars.data());
|
||||
unsigned i = fvars.size();
|
||||
buffer<expr> params;
|
||||
while (is_lambda(e)) {
|
||||
expr param_type = instantiate_rev(binding_domain(e), params.size(), params.data());
|
||||
expr param = m_lctx.mk_local_decl(ngen(), binding_name(e), param_type, binding_info(e));
|
||||
params.push_back(param);
|
||||
e = binding_body(e);
|
||||
}
|
||||
e = instantiate_rev(e, params.size(), params.data());
|
||||
buffer<expr> new_params, to_copy;
|
||||
split_fvars(fvars, params, new_params, to_copy);
|
||||
/*
|
||||
Variables in `to_copy` only depend on global constants
|
||||
and other variables in `to_copy`. Moreover, `params` do not depend on them.
|
||||
It is wasteful to pass them as new parameters to the new lifted declaration.
|
||||
We can just copy them. The code duplication is not problematic because later at `extract_closed`
|
||||
we will create global names for closed terms, and eliminate the redundancy.
|
||||
*/
|
||||
e = m_lctx.mk_lambda(to_copy, e);
|
||||
e = m_lctx.mk_lambda(params, e);
|
||||
expr code = abstract(e, new_params.size(), new_params.data());
|
||||
unsigned i = new_params.size();
|
||||
while (i > 0) {
|
||||
--i;
|
||||
local_decl const & decl = m_lctx.get_local_decl(fvars[i]);
|
||||
expr type = abstract(decl.get_type(), i, fvars.data());
|
||||
local_decl const & decl = m_lctx.get_local_decl(new_params[i]);
|
||||
expr type = abstract(decl.get_type(), i, new_params.data());
|
||||
code = ::lean::mk_lambda(decl.get_user_name(), type, code);
|
||||
}
|
||||
expr type = cheap_beta_reduce(type_checker(m_st).infer(code));
|
||||
|
|
@ -151,7 +225,7 @@ class eager_lambda_lifting_fn {
|
|||
declaration aux_ax = mk_axiom(n, names(), type, true /* meta */);
|
||||
m_st.env() = env().add(aux_ax, false);
|
||||
m_new_decls.push_back(comp_decl(n, code));
|
||||
return mk_app(mk_constant(n), fvars);
|
||||
return mk_app(mk_constant(n), new_params);
|
||||
}
|
||||
|
||||
/* Given a free variable `x`, follow let-decls and return a pair `(x, v)`.
|
||||
|
|
@ -208,6 +282,9 @@ class eager_lambda_lifting_fn {
|
|||
expr new_type = instantiate_rev(let_type(e), fvars.size(), fvars.data());
|
||||
expr new_val = visit(instantiate_rev(let_value(e), fvars.size(), fvars.data()), not_root, jp);
|
||||
expr new_fvar = m_lctx.mk_local_decl(ngen(), let_name(e), new_type, new_val);
|
||||
if (!has_fvar_except(new_type, m_closed_fvars) && !has_fvar_except(new_val, m_closed_fvars)) {
|
||||
m_closed_fvars.insert(fvar_name(new_fvar));
|
||||
}
|
||||
fvars.push_back(new_fvar);
|
||||
e = let_body(e);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue