fix: inline loop

This commit is contained in:
Leonardo de Moura 2021-02-04 13:08:28 -08:00
parent 768f2642bd
commit d494756d00
2 changed files with 76 additions and 9 deletions

View file

@ -1606,14 +1606,53 @@ class csimp_fn {
}
}
/* We don't inline recursive functions.
TODO(Leo): this predicate does not handle mutual recursion.
We need a better solution. Example: we tag which definitions are recursive when we create them. */
struct is_recursive_fn {
environment const & m_env;
csimp_cfg const & m_cfg;
bool m_before_erasure;
name m_target;
name_set m_visited;
is_recursive_fn(environment const & env, csimp_cfg const & cfg, bool before_erasure):
m_env(env), m_cfg(cfg), m_before_erasure(before_erasure) {
}
optional<constant_info> is_inline_candidate(name const & f) {
name c = m_before_erasure ? mk_cstage1_name(f) : mk_cstage2_name(f);
optional<constant_info> info = m_env.find(c);
if (!info || !info->is_definition()) {
return optional<constant_info>();
} else if (has_inline_attribute(m_env, f)) {
return info;
} else if (get_lcnf_size(m_env, info->get_value()) <= m_cfg.m_inline_threshold) {
return info;
} else {
return optional<constant_info>();
}
}
bool visit(name const & f) {
if (optional<constant_info> info = is_inline_candidate(f)) {
if (m_visited.contains(f))
return true;
m_visited.insert(f);
return static_cast<bool>(::lean::find(info->get_value(), [&](expr const & e, unsigned) {
return is_constant(e) && (const_name(e) == m_target || visit(const_name(e)));
}));
} else {
return false;
}
}
bool operator()(name const & f) {
m_target = f;
return visit(f);
}
};
/* We don't inline recursive functions. */
bool is_recursive(name const & c) {
constant_info info = env().get(c);
return static_cast<bool>(::lean::find(info.get_value(), [&](expr const & e, unsigned) {
return is_constant(e) && const_name(e) == c.get_prefix();
}));
return is_recursive_fn(env(), m_cfg, m_before_erasure)(c);
}
bool uses_unsafe_inductive(name const & c) {
@ -1690,7 +1729,7 @@ class csimp_fn {
is_constant(e))) { /* We only inline constants if they are marked with the `[inline]` or `[inline_if_reduce]` attrs */
return none_expr();
}
if (!inline_if_reduce_attr && is_recursive(c)) return none_expr();
if (!inline_if_reduce_attr && is_recursive(const_name(fn))) return none_expr();
if (uses_unsafe_inductive(c)) return none_expr();
expr new_fn = instantiate_value_lparams(*info, const_levels(fn));
if (inline_if_reduce_attr && !inline_attr) {
@ -1707,7 +1746,7 @@ class csimp_fn {
unsigned arity = get_num_nested_lambdas(info->get_value());
if (get_app_num_args(e) < arity || arity == 0) return none_expr();
if (get_lcnf_size(env(), info->get_value()) > m_cfg.m_inline_threshold) return none_expr();
if (is_recursive(c)) return none_expr();
if (is_recursive(const_name(fn))) return none_expr();
if (uses_unsafe_inductive(c)) return none_expr();
return some_expr(beta_reduce(info->get_value(), e, is_let_val));
}

View file

@ -0,0 +1,28 @@
namespace Test1
mutual
partial def f (a : Nat) : Nat := g a
partial def g (a : Nat) : Nat := f a
end
end Test1
namespace Test2
mutual
@[inline]
partial def f (a : Nat) : Nat := g a + g a + g a + g a
@[inline]
partial def g (a : Nat) : Nat := f a + f a + f a + f a
end
end Test2
namespace Test3
partial def unsafeFn1 {m} [Monad m] (a : α) : m α :=
unsafeFn1 a
end Test3