feat(library/compiler): inline small functions and join points
This commit is contained in:
parent
6786b2bcc8
commit
2d8d0d5a6c
3 changed files with 31 additions and 17 deletions
|
|
@ -8,6 +8,7 @@ Author: Leonardo de Moura
|
|||
#include "runtime/flet.h"
|
||||
#include "kernel/type_checker.h"
|
||||
#include "kernel/for_each_fn.h"
|
||||
#include "kernel/find_fn.h"
|
||||
#include "kernel/abstract.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "library/util.h"
|
||||
|
|
@ -21,7 +22,7 @@ Author: Leonardo de Moura
|
|||
namespace lean {
|
||||
csimp_cfg::csimp_cfg() {
|
||||
m_inline = true;
|
||||
m_inline_threshold = 4;
|
||||
m_inline_threshold = 1;
|
||||
m_float_cases_app = true;
|
||||
m_float_cases = true;
|
||||
m_float_cases_threshold = 40;
|
||||
|
|
@ -68,14 +69,8 @@ class csimp_fn {
|
|||
is_join_point_name(m_lctx.get_local_decl(fn).get_user_name());
|
||||
}
|
||||
|
||||
/* Very simple predicate used to decide whether we should inline joint-points or not.
|
||||
TODO(Leo): improve */
|
||||
bool is_small(expr const & e) const {
|
||||
if (is_app(e) && !is_cases_on_app(env(), e))
|
||||
return true;
|
||||
if (is_lambda(e))
|
||||
return is_small(binding_body(e));
|
||||
return false;
|
||||
bool is_small_join_point(expr const & e) const {
|
||||
return get_lcnf_size(env(), e) <= m_cfg.m_inline_jp_threshold;
|
||||
}
|
||||
|
||||
expr find(expr const & e, bool skip_mdata = true) const {
|
||||
|
|
@ -84,7 +79,7 @@ class csimp_fn {
|
|||
if (optional<expr> v = decl->get_value()) {
|
||||
if (!is_join_point_name(decl->get_user_name()))
|
||||
return find(*v, skip_mdata);
|
||||
else if (is_small(*v))
|
||||
else if (is_small_join_point(*v))
|
||||
return find(*v, skip_mdata);
|
||||
}
|
||||
}
|
||||
|
|
@ -1063,15 +1058,29 @@ class csimp_fn {
|
|||
return mk_cast(tc, r_type, e_type, r);
|
||||
}
|
||||
|
||||
/* We don't inline recursive functions.
|
||||
TODO(Leo): this predicate does not handle mutual recursion.
|
||||
We need a better solution. Example: we tag which definitions are recursive when we create them. */
|
||||
bool is_recursive(name const & c) {
|
||||
constant_info info = env().get(c);
|
||||
return static_cast<bool>(::lean::find(info.get_value(), [&](expr const & e, unsigned) {
|
||||
return is_constant(e) && const_name(e) == c.get_prefix();
|
||||
}));
|
||||
}
|
||||
|
||||
expr try_inline(expr const & fn, expr const & e, bool is_let_val) {
|
||||
lean_assert(is_constant(fn));
|
||||
lean_assert(is_eqp(find(get_app_fn(e)), fn));
|
||||
if (!m_cfg.m_inline) return e;
|
||||
if (has_noinline_attribute(env(), const_name(fn))) return e;
|
||||
optional<constant_info> info = env().find(mk_cstage1_name(const_name(fn)));
|
||||
name c = mk_cstage1_name(const_name(fn));
|
||||
optional<constant_info> info = env().find(c);
|
||||
if (!info || !info->is_definition()) return e;
|
||||
if (get_app_num_args(e) < get_num_nested_lambdas(info->get_value())) return e;
|
||||
/* TODO(Leo): check size and whether function is boring or not. */
|
||||
if (!has_inline_attribute(env(), const_name(fn))) return e;
|
||||
if (!has_inline_attribute(env(), const_name(fn)) &&
|
||||
get_lcnf_size(env(), info->get_value()) > m_cfg.m_inline_threshold)
|
||||
return e;
|
||||
if (is_recursive(c)) return e;
|
||||
expr new_fn = instantiate_value_lparams(*info, const_levels(fn));
|
||||
return beta_reduce(new_fn, e, is_let_val);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,11 +8,19 @@ Author: Leonardo de Moura
|
|||
#include "kernel/environment.h"
|
||||
namespace lean {
|
||||
struct csimp_cfg {
|
||||
/* If `m_inline` == false, then we will not inline `c` even if it is marked with the attribute `[inline]`. */
|
||||
bool m_inline;
|
||||
/* We inline "cheap" functions. We say a function is cheap if `get_lcnf_size(val) < m_inline_threshold`,
|
||||
and it is not marked as `[noinline]`. */
|
||||
unsigned m_inline_threshold;
|
||||
/* Enable float cases_on from application. Remark: this transformation is essential for monadic code. */
|
||||
bool m_float_cases_app;
|
||||
/* Enable float cases_on from cases_on and other expressions. */
|
||||
bool m_float_cases;
|
||||
/* We only perform float cases_on from cases_on and other expression if the potential code blowup is smaller
|
||||
than m_float_cases_threshold. */
|
||||
unsigned m_float_cases_threshold;
|
||||
/* We inline join-points that are smaller m_inline_threshold. */
|
||||
unsigned m_inline_jp_threshold;
|
||||
public:
|
||||
csimp_cfg();
|
||||
|
|
|
|||
|
|
@ -358,10 +358,7 @@ unsigned get_lcnf_size(environment const & env, expr e) {
|
|||
case expr_kind::MData:
|
||||
return 1;
|
||||
case expr_kind::Const:
|
||||
if (is_constructor(env, const_name(e)))
|
||||
return 0;
|
||||
else
|
||||
return 1;
|
||||
return 1;
|
||||
case expr_kind::Lambda:
|
||||
r = 1;
|
||||
while (is_lambda(e)) {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue