diff --git a/src/library/compiler/csimp.cpp b/src/library/compiler/csimp.cpp index d4ca1fddbf..b1c487b669 100644 --- a/src/library/compiler/csimp.cpp +++ b/src/library/compiler/csimp.cpp @@ -995,6 +995,17 @@ class csimp_fn { return mk_let_core(entries, e); } + name mk_let_name(name const & n) { + if (is_internal_name(n)) { + if (is_join_point_name(n)) + return next_jp_name(); + else + return next_name(); + } else { + return n; + } + } + expr visit_let(expr e) { buffer let_fvars; while (is_let(e)) { @@ -1003,13 +1014,7 @@ class csimp_fn { if (is_lcnf_atom(new_val)) { let_fvars.push_back(new_val); } else { - name n = let_name(e); - if (is_internal_name(n)) { - if (is_join_point_name(n)) - n = next_jp_name(); - else - n = next_name(); - } + name n = mk_let_name(let_name(e)); expr new_fvar = m_lctx.mk_local_decl(ngen(), n, new_type, new_val); let_fvars.push_back(new_fvar); m_fvars.push_back(new_fvar); @@ -1224,6 +1229,110 @@ class csimp_fn { return none_expr(); } + /* Return true iff `e` is of the form `fun (xs), let ys := ts in (ctor ...)`. + This auxiliary method is used at try_inline_proj_instance_aux. + It is a "quick" filter. */ + bool inline_proj_app_candidate(expr e) { + while (is_lambda(e)) + e = binding_body(e); + while (is_let(e)) + e = let_body(e); + return static_cast(is_constructor_app(env(), e)); + } + + /* + Given `let x := f as in ... x.i`, where where `f` is defined as + ``` + def f (xs) := + ... + let y_i := t[xs] in + ... + ctor ... y_i ... + ``` + reduce `x.i` into `t[as]`. + `y_i` may depend on other let-declarations, but we only inline if the number + of let-decl dependencies is less than `m_inline_threshold`. + + Remark: this transformation is only applied before erasure. + Remark: this transformation complements eager lambda lifting, + and has been designed to optimize code such as: + ``` + def f (x : nat) : Pro (Nat -> Nat) (Nat -> Bool) := + ((fun y, ), (fun z, )) + ``` + That is, `f` is "packing" functions in a structure and returning it. + Now, consider the following application: + ``` + (f a).1 b + ``` + With eager lambda lifting, we transform `f` into + ``` + def f._elambda_1 (x y) : Nat := + + def f._elambda_2 (x z) : Bool := + + def f (x : nat) : Pro (Nat -> Nat) (Nat -> Bool) := + (f._elambda_1 x, f._elambda_2 x) + ``` + Then, with this transformation, we transform `(f a).1` into + `f._elambda_1 a`, and then with application merge, we transform + `(f a).1 b` into `f._elambda_1 a b` + + See additional comments at `eager_lambda_lifting.cpp` */ + optional try_inline_proj_app(expr const & e, bool is_let_val) { + lean_assert(is_proj(e)); + if (!m_before_erasure) return none_expr(); + if (!proj_idx(e).is_small()) return none_expr(); + unsigned idx = proj_idx(e).get_small_value(); + expr s = find(proj_expr(e)); + buffer s_args; + expr const & s_fn = get_app_rev_args(s, s_args); + if (!is_constant(s_fn)) return none_expr(); + if (has_noinline_attribute(env(), const_name(s_fn))) return none_expr(); + optional info = env().find(mk_cstage1_name(const_name(s_fn))); + if (!info || !info->is_definition()) return none_expr(); + if (s_args.size() < get_num_nested_lambdas(info->get_value())) return none_expr(); + if (!inline_proj_app_candidate(info->get_value())) return none_expr(); + expr s_val = instantiate_value_lparams(*info, const_levels(s_fn)); + s_val = apply_beta(s_val, s_args.size(), s_args.data()); + buffer fvars; + while (is_let(s_val)) { + name n = mk_let_name(let_name(s_val)); + expr new_type = instantiate_rev(let_type(s_val), fvars.size(), fvars.data()); + expr new_val = instantiate_rev(let_value(s_val), fvars.size(), fvars.data()); + expr new_fvar = m_lctx.mk_local_decl(ngen(), n, new_type, new_val); + fvars.push_back(new_fvar); + s_val = let_body(s_val); + } + s_val = instantiate_rev(s_val, fvars.size(), fvars.data()); + lean_assert(is_constructor_app(env(), s_val)); + buffer k_args; + expr const & k = get_app_args(s_val, k_args); + constructor_val k_val = env().get(const_name(k)).to_constructor_val(); + lean_assert(k_val.get_nparams() + idx < k_args.size()); + expr val = k_args[k_val.get_nparams() + idx]; + buffer fvars_to_keep; + name_hash_set used_fvars; /* Set of free variables names used */ + collect_used(val, used_fvars); + unsigned i = fvars.size(); + while (i > 0) { + i--; + expr x = fvars[i]; + if (used_fvars.find(fvar_name(x)) != used_fvars.end()) { + local_decl x_decl = m_lctx.get_local_decl(x); + expr x_type = x_decl.get_type(); + expr x_val = *x_decl.get_value(); + collect_used(x_type, used_fvars); + collect_used(x_val, used_fvars); + fvars_to_keep.push_back(x); + if (fvars_to_keep.size() > m_cfg.m_inline_threshold) return none_expr(); + } + } + std::reverse(fvars_to_keep.begin(), fvars_to_keep.end()); + val = m_lctx.mk_lambda(fvars_to_keep, val); + return some_expr(visit(val, is_let_val)); + } + expr visit_proj(expr const & e, bool is_let_val) { expr s = find(proj_expr(e)); @@ -1238,6 +1347,10 @@ class csimp_fn { return *r; } + if (optional r = try_inline_proj_app(e, is_let_val)) { + return *r; + } + expr new_arg = visit_arg(proj_expr(e)); if (is_eqp(proj_expr(e), new_arg)) return e; diff --git a/tests/playground/parser/test1.lean b/tests/playground/parser/test1.lean index 4338507344..fe08a71a35 100644 --- a/tests/playground/parser/test1.lean +++ b/tests/playground/parser/test1.lean @@ -9,12 +9,10 @@ def mkNumPairKind : IO SyntaxNodeKind := nextKind `numPair def mkNumSetKind : IO SyntaxNodeKind := nextKind `numSet @[init mkNumSetKind] constant numSetKind : SyntaxNodeKind := default _ -@[inline2] def numPair : BasicParser := node numPairKind $ "("; number; ","; number; ")" -@[inline2] def numSet : BasicParser := node numSetKind $ "{"; sepBy number ","; "}"