feat(library/compiler): inline small functions and join points

This commit is contained in:
Leonardo de Moura 2018-09-27 18:16:41 -07:00
parent 6786b2bcc8
commit 2d8d0d5a6c
3 changed files with 31 additions and 17 deletions

View file

@ -8,6 +8,7 @@ Author: Leonardo de Moura
#include "runtime/flet.h"
#include "kernel/type_checker.h"
#include "kernel/for_each_fn.h"
#include "kernel/find_fn.h"
#include "kernel/abstract.h"
#include "kernel/instantiate.h"
#include "library/util.h"
@ -21,7 +22,7 @@ Author: Leonardo de Moura
namespace lean {
csimp_cfg::csimp_cfg() {
m_inline = true;
m_inline_threshold = 4;
m_inline_threshold = 1;
m_float_cases_app = true;
m_float_cases = true;
m_float_cases_threshold = 40;
@ -68,14 +69,8 @@ class csimp_fn {
is_join_point_name(m_lctx.get_local_decl(fn).get_user_name());
}
/* Very simple predicate used to decide whether we should inline joint-points or not.
TODO(Leo): improve */
bool is_small(expr const & e) const {
if (is_app(e) && !is_cases_on_app(env(), e))
return true;
if (is_lambda(e))
return is_small(binding_body(e));
return false;
bool is_small_join_point(expr const & e) const {
return get_lcnf_size(env(), e) <= m_cfg.m_inline_jp_threshold;
}
expr find(expr const & e, bool skip_mdata = true) const {
@ -84,7 +79,7 @@ class csimp_fn {
if (optional<expr> v = decl->get_value()) {
if (!is_join_point_name(decl->get_user_name()))
return find(*v, skip_mdata);
else if (is_small(*v))
else if (is_small_join_point(*v))
return find(*v, skip_mdata);
}
}
@ -1063,15 +1058,29 @@ class csimp_fn {
return mk_cast(tc, r_type, e_type, r);
}
/* We don't inline recursive functions.
TODO(Leo): this predicate does not handle mutual recursion.
We need a better solution. Example: we tag which definitions are recursive when we create them. */
bool is_recursive(name const & c) {
constant_info info = env().get(c);
return static_cast<bool>(::lean::find(info.get_value(), [&](expr const & e, unsigned) {
return is_constant(e) && const_name(e) == c.get_prefix();
}));
}
expr try_inline(expr const & fn, expr const & e, bool is_let_val) {
lean_assert(is_constant(fn));
lean_assert(is_eqp(find(get_app_fn(e)), fn));
if (!m_cfg.m_inline) return e;
if (has_noinline_attribute(env(), const_name(fn))) return e;
optional<constant_info> info = env().find(mk_cstage1_name(const_name(fn)));
name c = mk_cstage1_name(const_name(fn));
optional<constant_info> info = env().find(c);
if (!info || !info->is_definition()) return e;
if (get_app_num_args(e) < get_num_nested_lambdas(info->get_value())) return e;
/* TODO(Leo): check size and whether function is boring or not. */
if (!has_inline_attribute(env(), const_name(fn))) return e;
if (!has_inline_attribute(env(), const_name(fn)) &&
get_lcnf_size(env(), info->get_value()) > m_cfg.m_inline_threshold)
return e;
if (is_recursive(c)) return e;
expr new_fn = instantiate_value_lparams(*info, const_levels(fn));
return beta_reduce(new_fn, e, is_let_val);
}

View file

@ -8,11 +8,19 @@ Author: Leonardo de Moura
#include "kernel/environment.h"
namespace lean {
struct csimp_cfg {
/* If `m_inline` == false, then we will not inline `c` even if it is marked with the attribute `[inline]`. */
bool m_inline;
/* We inline "cheap" functions. We say a function is cheap if `get_lcnf_size(val) < m_inline_threshold`,
and it is not marked as `[noinline]`. */
unsigned m_inline_threshold;
/* Enable float cases_on from application. Remark: this transformation is essential for monadic code. */
bool m_float_cases_app;
/* Enable float cases_on from cases_on and other expressions. */
bool m_float_cases;
/* We only perform float cases_on from cases_on and other expression if the potential code blowup is smaller
than m_float_cases_threshold. */
unsigned m_float_cases_threshold;
/* We inline join-points that are smaller m_inline_threshold. */
unsigned m_inline_jp_threshold;
public:
csimp_cfg();

View file

@ -358,10 +358,7 @@ unsigned get_lcnf_size(environment const & env, expr e) {
case expr_kind::MData:
return 1;
case expr_kind::Const:
if (is_constructor(env, const_name(e)))
return 0;
else
return 1;
return 1;
case expr_kind::Lambda:
r = 1;
while (is_lambda(e)) {