fix: inline loop
This commit is contained in:
parent
768f2642bd
commit
d494756d00
2 changed files with 76 additions and 9 deletions
|
|
@ -1606,14 +1606,53 @@ class csimp_fn {
|
|||
}
|
||||
}
|
||||
|
||||
/* We don't inline recursive functions.
|
||||
TODO(Leo): this predicate does not handle mutual recursion.
|
||||
We need a better solution. Example: we tag which definitions are recursive when we create them. */
|
||||
struct is_recursive_fn {
|
||||
environment const & m_env;
|
||||
csimp_cfg const & m_cfg;
|
||||
bool m_before_erasure;
|
||||
name m_target;
|
||||
name_set m_visited;
|
||||
|
||||
is_recursive_fn(environment const & env, csimp_cfg const & cfg, bool before_erasure):
|
||||
m_env(env), m_cfg(cfg), m_before_erasure(before_erasure) {
|
||||
}
|
||||
|
||||
optional<constant_info> is_inline_candidate(name const & f) {
|
||||
name c = m_before_erasure ? mk_cstage1_name(f) : mk_cstage2_name(f);
|
||||
optional<constant_info> info = m_env.find(c);
|
||||
if (!info || !info->is_definition()) {
|
||||
return optional<constant_info>();
|
||||
} else if (has_inline_attribute(m_env, f)) {
|
||||
return info;
|
||||
} else if (get_lcnf_size(m_env, info->get_value()) <= m_cfg.m_inline_threshold) {
|
||||
return info;
|
||||
} else {
|
||||
return optional<constant_info>();
|
||||
}
|
||||
}
|
||||
|
||||
bool visit(name const & f) {
|
||||
if (optional<constant_info> info = is_inline_candidate(f)) {
|
||||
if (m_visited.contains(f))
|
||||
return true;
|
||||
m_visited.insert(f);
|
||||
return static_cast<bool>(::lean::find(info->get_value(), [&](expr const & e, unsigned) {
|
||||
return is_constant(e) && (const_name(e) == m_target || visit(const_name(e)));
|
||||
}));
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool operator()(name const & f) {
|
||||
m_target = f;
|
||||
return visit(f);
|
||||
}
|
||||
};
|
||||
|
||||
/* We don't inline recursive functions. */
|
||||
bool is_recursive(name const & c) {
|
||||
constant_info info = env().get(c);
|
||||
return static_cast<bool>(::lean::find(info.get_value(), [&](expr const & e, unsigned) {
|
||||
return is_constant(e) && const_name(e) == c.get_prefix();
|
||||
}));
|
||||
return is_recursive_fn(env(), m_cfg, m_before_erasure)(c);
|
||||
}
|
||||
|
||||
bool uses_unsafe_inductive(name const & c) {
|
||||
|
|
@ -1690,7 +1729,7 @@ class csimp_fn {
|
|||
is_constant(e))) { /* We only inline constants if they are marked with the `[inline]` or `[inline_if_reduce]` attrs */
|
||||
return none_expr();
|
||||
}
|
||||
if (!inline_if_reduce_attr && is_recursive(c)) return none_expr();
|
||||
if (!inline_if_reduce_attr && is_recursive(const_name(fn))) return none_expr();
|
||||
if (uses_unsafe_inductive(c)) return none_expr();
|
||||
expr new_fn = instantiate_value_lparams(*info, const_levels(fn));
|
||||
if (inline_if_reduce_attr && !inline_attr) {
|
||||
|
|
@ -1707,7 +1746,7 @@ class csimp_fn {
|
|||
unsigned arity = get_num_nested_lambdas(info->get_value());
|
||||
if (get_app_num_args(e) < arity || arity == 0) return none_expr();
|
||||
if (get_lcnf_size(env(), info->get_value()) > m_cfg.m_inline_threshold) return none_expr();
|
||||
if (is_recursive(c)) return none_expr();
|
||||
if (is_recursive(const_name(fn))) return none_expr();
|
||||
if (uses_unsafe_inductive(c)) return none_expr();
|
||||
return some_expr(beta_reduce(info->get_value(), e, is_let_val));
|
||||
}
|
||||
|
|
|
|||
28
tests/lean/run/inlineLoop.lean
Normal file
28
tests/lean/run/inlineLoop.lean
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
namespace Test1
|
||||
mutual
|
||||
|
||||
partial def f (a : Nat) : Nat := g a
|
||||
|
||||
partial def g (a : Nat) : Nat := f a
|
||||
|
||||
end
|
||||
end Test1
|
||||
|
||||
namespace Test2
|
||||
mutual
|
||||
|
||||
@[inline]
|
||||
partial def f (a : Nat) : Nat := g a + g a + g a + g a
|
||||
|
||||
@[inline]
|
||||
partial def g (a : Nat) : Nat := f a + f a + f a + f a
|
||||
|
||||
end
|
||||
end Test2
|
||||
|
||||
namespace Test3
|
||||
|
||||
partial def unsafeFn1 {m} [Monad m] (a : α) : m α :=
|
||||
unsafeFn1 a
|
||||
|
||||
end Test3
|
||||
Loading…
Add table
Reference in a new issue