diff --git a/src/library/compiler/csimp.cpp b/src/library/compiler/csimp.cpp index bdc6c4fbba..cd799722d5 100644 --- a/src/library/compiler/csimp.cpp +++ b/src/library/compiler/csimp.cpp @@ -1606,14 +1606,53 @@ class csimp_fn { } } - /* We don't inline recursive functions. - TODO(Leo): this predicate does not handle mutual recursion. - We need a better solution. Example: we tag which definitions are recursive when we create them. */ + struct is_recursive_fn { + environment const & m_env; + csimp_cfg const & m_cfg; + bool m_before_erasure; + name m_target; + name_set m_visited; + + is_recursive_fn(environment const & env, csimp_cfg const & cfg, bool before_erasure): + m_env(env), m_cfg(cfg), m_before_erasure(before_erasure) { + } + + optional is_inline_candidate(name const & f) { + name c = m_before_erasure ? mk_cstage1_name(f) : mk_cstage2_name(f); + optional info = m_env.find(c); + if (!info || !info->is_definition()) { + return optional(); + } else if (has_inline_attribute(m_env, f)) { + return info; + } else if (get_lcnf_size(m_env, info->get_value()) <= m_cfg.m_inline_threshold) { + return info; + } else { + return optional(); + } + } + + bool visit(name const & f) { + if (optional info = is_inline_candidate(f)) { + if (m_visited.contains(f)) + return true; + m_visited.insert(f); + return static_cast(::lean::find(info->get_value(), [&](expr const & e, unsigned) { + return is_constant(e) && (const_name(e) == m_target || visit(const_name(e))); + })); + } else { + return false; + } + } + + bool operator()(name const & f) { + m_target = f; + return visit(f); + } + }; + + /* We don't inline recursive functions. */ bool is_recursive(name const & c) { - constant_info info = env().get(c); - return static_cast(::lean::find(info.get_value(), [&](expr const & e, unsigned) { - return is_constant(e) && const_name(e) == c.get_prefix(); - })); + return is_recursive_fn(env(), m_cfg, m_before_erasure)(c); } bool uses_unsafe_inductive(name const & c) { @@ -1690,7 +1729,7 @@ class csimp_fn { is_constant(e))) { /* We only inline constants if they are marked with the `[inline]` or `[inline_if_reduce]` attrs */ return none_expr(); } - if (!inline_if_reduce_attr && is_recursive(c)) return none_expr(); + if (!inline_if_reduce_attr && is_recursive(const_name(fn))) return none_expr(); if (uses_unsafe_inductive(c)) return none_expr(); expr new_fn = instantiate_value_lparams(*info, const_levels(fn)); if (inline_if_reduce_attr && !inline_attr) { @@ -1707,7 +1746,7 @@ class csimp_fn { unsigned arity = get_num_nested_lambdas(info->get_value()); if (get_app_num_args(e) < arity || arity == 0) return none_expr(); if (get_lcnf_size(env(), info->get_value()) > m_cfg.m_inline_threshold) return none_expr(); - if (is_recursive(c)) return none_expr(); + if (is_recursive(const_name(fn))) return none_expr(); if (uses_unsafe_inductive(c)) return none_expr(); return some_expr(beta_reduce(info->get_value(), e, is_let_val)); } diff --git a/tests/lean/run/inlineLoop.lean b/tests/lean/run/inlineLoop.lean new file mode 100644 index 0000000000..514bfe9d76 --- /dev/null +++ b/tests/lean/run/inlineLoop.lean @@ -0,0 +1,28 @@ +namespace Test1 +mutual + + partial def f (a : Nat) : Nat := g a + + partial def g (a : Nat) : Nat := f a + +end +end Test1 + +namespace Test2 +mutual + + @[inline] + partial def f (a : Nat) : Nat := g a + g a + g a + g a + + @[inline] + partial def g (a : Nat) : Nat := f a + f a + f a + f a + +end +end Test2 + +namespace Test3 + +partial def unsafeFn1 {m} [Monad m] (a : α) : m α := + unsafeFn1 a + +end Test3