feat(library/compiler/csimp): add transformation to complement eager lambda lifting

This commit is contained in:
Leonardo de Moura 2019-04-18 13:12:11 -07:00
parent e3dfc73b3a
commit 0ea944ad9f
2 changed files with 120 additions and 9 deletions

View file

@ -995,6 +995,17 @@ class csimp_fn {
return mk_let_core(entries, e); return mk_let_core(entries, e);
} }
name mk_let_name(name const & n) {
if (is_internal_name(n)) {
if (is_join_point_name(n))
return next_jp_name();
else
return next_name();
} else {
return n;
}
}
expr visit_let(expr e) { expr visit_let(expr e) {
buffer<expr> let_fvars; buffer<expr> let_fvars;
while (is_let(e)) { while (is_let(e)) {
@ -1003,13 +1014,7 @@ class csimp_fn {
if (is_lcnf_atom(new_val)) { if (is_lcnf_atom(new_val)) {
let_fvars.push_back(new_val); let_fvars.push_back(new_val);
} else { } else {
name n = let_name(e); name n = mk_let_name(let_name(e));
if (is_internal_name(n)) {
if (is_join_point_name(n))
n = next_jp_name();
else
n = next_name();
}
expr new_fvar = m_lctx.mk_local_decl(ngen(), n, new_type, new_val); expr new_fvar = m_lctx.mk_local_decl(ngen(), n, new_type, new_val);
let_fvars.push_back(new_fvar); let_fvars.push_back(new_fvar);
m_fvars.push_back(new_fvar); m_fvars.push_back(new_fvar);
@ -1224,6 +1229,110 @@ class csimp_fn {
return none_expr(); return none_expr();
} }
/* Return true iff `e` is of the form `fun (xs), let ys := ts in (ctor ...)`.
This auxiliary method is used at try_inline_proj_instance_aux.
It is a "quick" filter. */
bool inline_proj_app_candidate(expr e) {
while (is_lambda(e))
e = binding_body(e);
while (is_let(e))
e = let_body(e);
return static_cast<bool>(is_constructor_app(env(), e));
}
/*
Given `let x := f as in ... x.i`, where where `f` is defined as
```
def f (xs) :=
...
let y_i := t[xs] in
...
ctor ... y_i ...
```
reduce `x.i` into `t[as]`.
`y_i` may depend on other let-declarations, but we only inline if the number
of let-decl dependencies is less than `m_inline_threshold`.
Remark: this transformation is only applied before erasure.
Remark: this transformation complements eager lambda lifting,
and has been designed to optimize code such as:
```
def f (x : nat) : Pro (Nat -> Nat) (Nat -> Bool) :=
((fun y, <code1 using x y>), (fun z, <code2 using x z>))
```
That is, `f` is "packing" functions in a structure and returning it.
Now, consider the following application:
```
(f a).1 b
```
With eager lambda lifting, we transform `f` into
```
def f._elambda_1 (x y) : Nat :=
<code1 using x y>
def f._elambda_2 (x z) : Bool :=
<code2 using x z>
def f (x : nat) : Pro (Nat -> Nat) (Nat -> Bool) :=
(f._elambda_1 x, f._elambda_2 x)
```
Then, with this transformation, we transform `(f a).1` into
`f._elambda_1 a`, and then with application merge, we transform
`(f a).1 b` into `f._elambda_1 a b`
See additional comments at `eager_lambda_lifting.cpp` */
optional<expr> try_inline_proj_app(expr const & e, bool is_let_val) {
lean_assert(is_proj(e));
if (!m_before_erasure) return none_expr();
if (!proj_idx(e).is_small()) return none_expr();
unsigned idx = proj_idx(e).get_small_value();
expr s = find(proj_expr(e));
buffer<expr> s_args;
expr const & s_fn = get_app_rev_args(s, s_args);
if (!is_constant(s_fn)) return none_expr();
if (has_noinline_attribute(env(), const_name(s_fn))) return none_expr();
optional<constant_info> info = env().find(mk_cstage1_name(const_name(s_fn)));
if (!info || !info->is_definition()) return none_expr();
if (s_args.size() < get_num_nested_lambdas(info->get_value())) return none_expr();
if (!inline_proj_app_candidate(info->get_value())) return none_expr();
expr s_val = instantiate_value_lparams(*info, const_levels(s_fn));
s_val = apply_beta(s_val, s_args.size(), s_args.data());
buffer<expr> fvars;
while (is_let(s_val)) {
name n = mk_let_name(let_name(s_val));
expr new_type = instantiate_rev(let_type(s_val), fvars.size(), fvars.data());
expr new_val = instantiate_rev(let_value(s_val), fvars.size(), fvars.data());
expr new_fvar = m_lctx.mk_local_decl(ngen(), n, new_type, new_val);
fvars.push_back(new_fvar);
s_val = let_body(s_val);
}
s_val = instantiate_rev(s_val, fvars.size(), fvars.data());
lean_assert(is_constructor_app(env(), s_val));
buffer<expr> k_args;
expr const & k = get_app_args(s_val, k_args);
constructor_val k_val = env().get(const_name(k)).to_constructor_val();
lean_assert(k_val.get_nparams() + idx < k_args.size());
expr val = k_args[k_val.get_nparams() + idx];
buffer<expr> fvars_to_keep;
name_hash_set used_fvars; /* Set of free variables names used */
collect_used(val, used_fvars);
unsigned i = fvars.size();
while (i > 0) {
i--;
expr x = fvars[i];
if (used_fvars.find(fvar_name(x)) != used_fvars.end()) {
local_decl x_decl = m_lctx.get_local_decl(x);
expr x_type = x_decl.get_type();
expr x_val = *x_decl.get_value();
collect_used(x_type, used_fvars);
collect_used(x_val, used_fvars);
fvars_to_keep.push_back(x);
if (fvars_to_keep.size() > m_cfg.m_inline_threshold) return none_expr();
}
}
std::reverse(fvars_to_keep.begin(), fvars_to_keep.end());
val = m_lctx.mk_lambda(fvars_to_keep, val);
return some_expr(visit(val, is_let_val));
}
expr visit_proj(expr const & e, bool is_let_val) { expr visit_proj(expr const & e, bool is_let_val) {
expr s = find(proj_expr(e)); expr s = find(proj_expr(e));
@ -1238,6 +1347,10 @@ class csimp_fn {
return *r; return *r;
} }
if (optional<expr> r = try_inline_proj_app(e, is_let_val)) {
return *r;
}
expr new_arg = visit_arg(proj_expr(e)); expr new_arg = visit_arg(proj_expr(e));
if (is_eqp(proj_expr(e), new_arg)) if (is_eqp(proj_expr(e), new_arg))
return e; return e;

View file

@ -9,12 +9,10 @@ def mkNumPairKind : IO SyntaxNodeKind := nextKind `numPair
def mkNumSetKind : IO SyntaxNodeKind := nextKind `numSet def mkNumSetKind : IO SyntaxNodeKind := nextKind `numSet
@[init mkNumSetKind] constant numSetKind : SyntaxNodeKind := default _ @[init mkNumSetKind] constant numSetKind : SyntaxNodeKind := default _
@[inline2]
def numPair : BasicParser := def numPair : BasicParser :=
node numPairKind $ node numPairKind $
"("; number; ","; number; ")" "("; number; ","; number; ")"
@[inline2]
def numSet : BasicParser := def numSet : BasicParser :=
node numSetKind $ node numSetKind $
"{"; sepBy number ","; "}" "{"; sepBy number ","; "}"