From 2d8d0d5a6cff6820b2fc212fb7df577ff866a730 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 27 Sep 2018 18:16:41 -0700 Subject: [PATCH] feat(library/compiler): inline small functions and join points --- src/library/compiler/csimp.cpp | 35 +++++++++++++++++++++------------- src/library/compiler/csimp.h | 8 ++++++++ src/library/compiler/util.cpp | 5 +---- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/library/compiler/csimp.cpp b/src/library/compiler/csimp.cpp index 825a940f7c..aaeb20c8e5 100644 --- a/src/library/compiler/csimp.cpp +++ b/src/library/compiler/csimp.cpp @@ -8,6 +8,7 @@ Author: Leonardo de Moura #include "runtime/flet.h" #include "kernel/type_checker.h" #include "kernel/for_each_fn.h" +#include "kernel/find_fn.h" #include "kernel/abstract.h" #include "kernel/instantiate.h" #include "library/util.h" @@ -21,7 +22,7 @@ Author: Leonardo de Moura namespace lean { csimp_cfg::csimp_cfg() { m_inline = true; - m_inline_threshold = 4; + m_inline_threshold = 1; m_float_cases_app = true; m_float_cases = true; m_float_cases_threshold = 40; @@ -68,14 +69,8 @@ class csimp_fn { is_join_point_name(m_lctx.get_local_decl(fn).get_user_name()); } - /* Very simple predicate used to decide whether we should inline joint-points or not. - TODO(Leo): improve */ - bool is_small(expr const & e) const { - if (is_app(e) && !is_cases_on_app(env(), e)) - return true; - if (is_lambda(e)) - return is_small(binding_body(e)); - return false; + bool is_small_join_point(expr const & e) const { + return get_lcnf_size(env(), e) <= m_cfg.m_inline_jp_threshold; } expr find(expr const & e, bool skip_mdata = true) const { @@ -84,7 +79,7 @@ class csimp_fn { if (optional v = decl->get_value()) { if (!is_join_point_name(decl->get_user_name())) return find(*v, skip_mdata); - else if (is_small(*v)) + else if (is_small_join_point(*v)) return find(*v, skip_mdata); } } @@ -1063,15 +1058,29 @@ class csimp_fn { return mk_cast(tc, r_type, e_type, r); } + /* We don't inline recursive functions. + TODO(Leo): this predicate does not handle mutual recursion. + We need a better solution. Example: we tag which definitions are recursive when we create them. */ + bool is_recursive(name const & c) { + constant_info info = env().get(c); + return static_cast(::lean::find(info.get_value(), [&](expr const & e, unsigned) { + return is_constant(e) && const_name(e) == c.get_prefix(); + })); + } + expr try_inline(expr const & fn, expr const & e, bool is_let_val) { lean_assert(is_constant(fn)); lean_assert(is_eqp(find(get_app_fn(e)), fn)); + if (!m_cfg.m_inline) return e; if (has_noinline_attribute(env(), const_name(fn))) return e; - optional info = env().find(mk_cstage1_name(const_name(fn))); + name c = mk_cstage1_name(const_name(fn)); + optional info = env().find(c); if (!info || !info->is_definition()) return e; if (get_app_num_args(e) < get_num_nested_lambdas(info->get_value())) return e; - /* TODO(Leo): check size and whether function is boring or not. */ - if (!has_inline_attribute(env(), const_name(fn))) return e; + if (!has_inline_attribute(env(), const_name(fn)) && + get_lcnf_size(env(), info->get_value()) > m_cfg.m_inline_threshold) + return e; + if (is_recursive(c)) return e; expr new_fn = instantiate_value_lparams(*info, const_levels(fn)); return beta_reduce(new_fn, e, is_let_val); } diff --git a/src/library/compiler/csimp.h b/src/library/compiler/csimp.h index 14de6a87df..ccac0b826e 100644 --- a/src/library/compiler/csimp.h +++ b/src/library/compiler/csimp.h @@ -8,11 +8,19 @@ Author: Leonardo de Moura #include "kernel/environment.h" namespace lean { struct csimp_cfg { + /* If `m_inline` == false, then we will not inline `c` even if it is marked with the attribute `[inline]`. */ bool m_inline; + /* We inline "cheap" functions. We say a function is cheap if `get_lcnf_size(val) < m_inline_threshold`, + and it is not marked as `[noinline]`. */ unsigned m_inline_threshold; + /* Enable float cases_on from application. Remark: this transformation is essential for monadic code. */ bool m_float_cases_app; + /* Enable float cases_on from cases_on and other expressions. */ bool m_float_cases; + /* We only perform float cases_on from cases_on and other expression if the potential code blowup is smaller + than m_float_cases_threshold. */ unsigned m_float_cases_threshold; + /* We inline join-points that are smaller m_inline_threshold. */ unsigned m_inline_jp_threshold; public: csimp_cfg(); diff --git a/src/library/compiler/util.cpp b/src/library/compiler/util.cpp index a51699bf76..c20741345c 100644 --- a/src/library/compiler/util.cpp +++ b/src/library/compiler/util.cpp @@ -358,10 +358,7 @@ unsigned get_lcnf_size(environment const & env, expr e) { case expr_kind::MData: return 1; case expr_kind::Const: - if (is_constructor(env, const_name(e))) - return 0; - else - return 1; + return 1; case expr_kind::Lambda: r = 1; while (is_lambda(e)) {