feat(library/compiler/csimp): add transformation to complement eager lambda lifting
This commit is contained in:
parent
e3dfc73b3a
commit
0ea944ad9f
2 changed files with 120 additions and 9 deletions
|
|
@ -995,6 +995,17 @@ class csimp_fn {
|
|||
return mk_let_core(entries, e);
|
||||
}
|
||||
|
||||
name mk_let_name(name const & n) {
|
||||
if (is_internal_name(n)) {
|
||||
if (is_join_point_name(n))
|
||||
return next_jp_name();
|
||||
else
|
||||
return next_name();
|
||||
} else {
|
||||
return n;
|
||||
}
|
||||
}
|
||||
|
||||
expr visit_let(expr e) {
|
||||
buffer<expr> let_fvars;
|
||||
while (is_let(e)) {
|
||||
|
|
@ -1003,13 +1014,7 @@ class csimp_fn {
|
|||
if (is_lcnf_atom(new_val)) {
|
||||
let_fvars.push_back(new_val);
|
||||
} else {
|
||||
name n = let_name(e);
|
||||
if (is_internal_name(n)) {
|
||||
if (is_join_point_name(n))
|
||||
n = next_jp_name();
|
||||
else
|
||||
n = next_name();
|
||||
}
|
||||
name n = mk_let_name(let_name(e));
|
||||
expr new_fvar = m_lctx.mk_local_decl(ngen(), n, new_type, new_val);
|
||||
let_fvars.push_back(new_fvar);
|
||||
m_fvars.push_back(new_fvar);
|
||||
|
|
@ -1224,6 +1229,110 @@ class csimp_fn {
|
|||
return none_expr();
|
||||
}
|
||||
|
||||
/* Return true iff `e` is of the form `fun (xs), let ys := ts in (ctor ...)`.
|
||||
This auxiliary method is used at try_inline_proj_instance_aux.
|
||||
It is a "quick" filter. */
|
||||
bool inline_proj_app_candidate(expr e) {
|
||||
while (is_lambda(e))
|
||||
e = binding_body(e);
|
||||
while (is_let(e))
|
||||
e = let_body(e);
|
||||
return static_cast<bool>(is_constructor_app(env(), e));
|
||||
}
|
||||
|
||||
/*
|
||||
Given `let x := f as in ... x.i`, where where `f` is defined as
|
||||
```
|
||||
def f (xs) :=
|
||||
...
|
||||
let y_i := t[xs] in
|
||||
...
|
||||
ctor ... y_i ...
|
||||
```
|
||||
reduce `x.i` into `t[as]`.
|
||||
`y_i` may depend on other let-declarations, but we only inline if the number
|
||||
of let-decl dependencies is less than `m_inline_threshold`.
|
||||
|
||||
Remark: this transformation is only applied before erasure.
|
||||
Remark: this transformation complements eager lambda lifting,
|
||||
and has been designed to optimize code such as:
|
||||
```
|
||||
def f (x : nat) : Pro (Nat -> Nat) (Nat -> Bool) :=
|
||||
((fun y, <code1 using x y>), (fun z, <code2 using x z>))
|
||||
```
|
||||
That is, `f` is "packing" functions in a structure and returning it.
|
||||
Now, consider the following application:
|
||||
```
|
||||
(f a).1 b
|
||||
```
|
||||
With eager lambda lifting, we transform `f` into
|
||||
```
|
||||
def f._elambda_1 (x y) : Nat :=
|
||||
<code1 using x y>
|
||||
def f._elambda_2 (x z) : Bool :=
|
||||
<code2 using x z>
|
||||
def f (x : nat) : Pro (Nat -> Nat) (Nat -> Bool) :=
|
||||
(f._elambda_1 x, f._elambda_2 x)
|
||||
```
|
||||
Then, with this transformation, we transform `(f a).1` into
|
||||
`f._elambda_1 a`, and then with application merge, we transform
|
||||
`(f a).1 b` into `f._elambda_1 a b`
|
||||
|
||||
See additional comments at `eager_lambda_lifting.cpp` */
|
||||
optional<expr> try_inline_proj_app(expr const & e, bool is_let_val) {
|
||||
lean_assert(is_proj(e));
|
||||
if (!m_before_erasure) return none_expr();
|
||||
if (!proj_idx(e).is_small()) return none_expr();
|
||||
unsigned idx = proj_idx(e).get_small_value();
|
||||
expr s = find(proj_expr(e));
|
||||
buffer<expr> s_args;
|
||||
expr const & s_fn = get_app_rev_args(s, s_args);
|
||||
if (!is_constant(s_fn)) return none_expr();
|
||||
if (has_noinline_attribute(env(), const_name(s_fn))) return none_expr();
|
||||
optional<constant_info> info = env().find(mk_cstage1_name(const_name(s_fn)));
|
||||
if (!info || !info->is_definition()) return none_expr();
|
||||
if (s_args.size() < get_num_nested_lambdas(info->get_value())) return none_expr();
|
||||
if (!inline_proj_app_candidate(info->get_value())) return none_expr();
|
||||
expr s_val = instantiate_value_lparams(*info, const_levels(s_fn));
|
||||
s_val = apply_beta(s_val, s_args.size(), s_args.data());
|
||||
buffer<expr> fvars;
|
||||
while (is_let(s_val)) {
|
||||
name n = mk_let_name(let_name(s_val));
|
||||
expr new_type = instantiate_rev(let_type(s_val), fvars.size(), fvars.data());
|
||||
expr new_val = instantiate_rev(let_value(s_val), fvars.size(), fvars.data());
|
||||
expr new_fvar = m_lctx.mk_local_decl(ngen(), n, new_type, new_val);
|
||||
fvars.push_back(new_fvar);
|
||||
s_val = let_body(s_val);
|
||||
}
|
||||
s_val = instantiate_rev(s_val, fvars.size(), fvars.data());
|
||||
lean_assert(is_constructor_app(env(), s_val));
|
||||
buffer<expr> k_args;
|
||||
expr const & k = get_app_args(s_val, k_args);
|
||||
constructor_val k_val = env().get(const_name(k)).to_constructor_val();
|
||||
lean_assert(k_val.get_nparams() + idx < k_args.size());
|
||||
expr val = k_args[k_val.get_nparams() + idx];
|
||||
buffer<expr> fvars_to_keep;
|
||||
name_hash_set used_fvars; /* Set of free variables names used */
|
||||
collect_used(val, used_fvars);
|
||||
unsigned i = fvars.size();
|
||||
while (i > 0) {
|
||||
i--;
|
||||
expr x = fvars[i];
|
||||
if (used_fvars.find(fvar_name(x)) != used_fvars.end()) {
|
||||
local_decl x_decl = m_lctx.get_local_decl(x);
|
||||
expr x_type = x_decl.get_type();
|
||||
expr x_val = *x_decl.get_value();
|
||||
collect_used(x_type, used_fvars);
|
||||
collect_used(x_val, used_fvars);
|
||||
fvars_to_keep.push_back(x);
|
||||
if (fvars_to_keep.size() > m_cfg.m_inline_threshold) return none_expr();
|
||||
}
|
||||
}
|
||||
std::reverse(fvars_to_keep.begin(), fvars_to_keep.end());
|
||||
val = m_lctx.mk_lambda(fvars_to_keep, val);
|
||||
return some_expr(visit(val, is_let_val));
|
||||
}
|
||||
|
||||
expr visit_proj(expr const & e, bool is_let_val) {
|
||||
expr s = find(proj_expr(e));
|
||||
|
||||
|
|
@ -1238,6 +1347,10 @@ class csimp_fn {
|
|||
return *r;
|
||||
}
|
||||
|
||||
if (optional<expr> r = try_inline_proj_app(e, is_let_val)) {
|
||||
return *r;
|
||||
}
|
||||
|
||||
expr new_arg = visit_arg(proj_expr(e));
|
||||
if (is_eqp(proj_expr(e), new_arg))
|
||||
return e;
|
||||
|
|
|
|||
|
|
@ -9,12 +9,10 @@ def mkNumPairKind : IO SyntaxNodeKind := nextKind `numPair
|
|||
def mkNumSetKind : IO SyntaxNodeKind := nextKind `numSet
|
||||
@[init mkNumSetKind] constant numSetKind : SyntaxNodeKind := default _
|
||||
|
||||
@[inline2]
|
||||
def numPair : BasicParser :=
|
||||
node numPairKind $
|
||||
"("; number; ","; number; ")"
|
||||
|
||||
@[inline2]
|
||||
def numSet : BasicParser :=
|
||||
node numSetKind $
|
||||
"{"; sepBy number ","; "}"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue