feat(library/compiler/csimp): float all cases

Motivation: explicit control flow graph

TODO: `split_entries` is disabled for now.
I believe the new feature exposed a bug in `move_to_entries`.
I will fix that issue in a follow-up commit.
This commit is contained in:
Leonardo de Moura 2018-10-25 15:18:25 -07:00
parent 4ba3cc390f
commit 7ec00c97e9
2 changed files with 38 additions and 91 deletions

View file

@ -28,8 +28,6 @@ csimp_cfg::csimp_cfg(options const &):
/* Default constructor: assigns the built-in default value to every configuration field. */
csimp_cfg::csimp_cfg() {
/* Presumably the master switch for inlining -- TODO confirm against the field's declaration. */
m_inline = true;
/* A function is considered "cheap" (and is inlined) when `get_lcnf_size(val) < m_inline_threshold`
and it is not marked as `[noinline]`. */
m_inline_threshold = 1;
/* Enable float cases_on from application. */
m_float_cases_app = true;
/* Enable float cases_on from cases_on and other expressions. */
m_float_cases = true;
/* Float cases_on from cases_on and other expressions is only performed if the potential
code blowup is smaller than this threshold. */
m_float_cases_threshold = 20;
/* NOTE(review): looks like a size threshold for inlining join points -- confirm against the
field's declaration, which is not visible here. */
m_inline_jp_threshold = 2;
}
@ -309,30 +307,6 @@ class csimp_fn {
});
}
/* Return true iff the free variable `x` occurs in a projection or is the major premise of
a `cases_on` application in `e`.

More precisely, the result is true when some visited subterm `s` of `e` satisfies one of:
- `s` is a projection whose projected expression (`proj_expr(s)`) is exactly `x`;
- `s` is a `cases_on` application whose number of arguments equals `get_cases_on_arity`
for that `cases_on` constant, and whose major premise is exactly `x`.

Preconditions (asserted): `m_before_erasure` is set and `x` is a free variable. */
bool is_proj_or_cases_on_arg_at(expr const & x, expr const & e) {
lean_assert(m_before_erasure);
lean_assert(is_fvar(x));
/* Fast path: a term without free variables cannot mention `x`. */
if (!has_fvar(e)) return false;
bool found = false;
/* NOTE(review): the callback's boolean result presumably tells `for_each` whether to descend
into the children of `s` (false = do not descend) -- confirm against `for_each`. */
for_each(e, [&](expr const & s, unsigned) {
/* Subterms without free variables cannot contain `x`. */
if (!has_fvar(s)) return false;
/* An occurrence was already recorded; no further inspection is needed. */
if (found) return false;
if (is_proj(s) && proj_expr(s) == x) {
/* `x` is the expression being projected. */
found = true;
return false;
} else if (is_cases_on_app(env(), s) &&
get_app_num_args(s) == get_cases_on_arity(env(), const_name(get_app_fn(s)), m_before_erasure) &&
get_cases_on_app_major(env(), s, m_before_erasure) == x) {
/* `x` is the major premise of a `cases_on` application with exactly the expected arity. */
found = true;
return false;
}
return true;
});
return found;
}
/* Collect information for deciding whether `float_cases_on` is useful or not, and control
code blowup. */
struct cases_info_result {
@ -383,31 +357,20 @@ class csimp_fn {
This method creates one (or more) join-point(s) for `e` (if needed).
Return `none` if the code size increase is above the threshold.
Remark: it may produce type incorrect terms. */
optional<expr> mk_join_point_float_cases_on(expr const & fvar, expr const & e, expr const & c) {
lean_assert(m_before_erasure);
expr mk_join_point_float_cases_on(expr const & fvar, expr const & e, expr const & c) {
lean_assert(is_cases_on_app(env(), c));
unsigned e_size = get_lcnf_size(env(), e);
if (e_size == 1) {
return some_expr(e);
}
if (!is_proj_or_cases_on_arg_at(fvar, e)) {
/* It is not worthwhile to apply `float_cases_on` since `e` does not project or destruct the result produced
by `c`. */
return none_expr();
return e;
}
cases_info_result c_info;
collect_cases_info(c, c_info);
if (c_info.m_num_cnstr_results == 0) {
/* It is not worthwhile to apply `float_cases_on` since none of `c` branches return a constructor. */
return none_expr();
}
lean_assert(c_info.m_num_branches > 0);
lean_assert(c_info.m_num_cnstr_results <= c_info.m_num_branches);
unsigned code_increase = e_size*(c_info.m_num_branches - 1);
if (code_increase <= m_cfg.m_float_cases_threshold) {
return some(e);
} else if (is_cases_on_app(env(), e)) {
local_decl fvar_decl = m_lctx.get_local_decl(fvar);
return e;
}
local_decl fvar_decl = m_lctx.get_local_decl(fvar);
if (is_cases_on_app(env(), e)) {
buffer<expr> args;
expr const & fn = get_app_args(e, args);
inductive_val e_I_val = get_cases_on_inductive_val(env(), fn);
@ -420,7 +383,6 @@ class csimp_fn {
unsigned new_code_increase = e_compressed_size*(c_info.m_num_branches - c_info.m_num_cnstr_results);
if (new_code_increase <= m_cfg.m_float_cases_threshold) {
unsigned branch_threshold = m_cfg.m_float_cases_threshold / (c_info.m_num_branches - 1);
lean_assert(m_before_erasure);
unsigned begin_minors; unsigned end_minors;
std::tie(begin_minors, end_minors) = get_cases_on_minors_range(env(), const_name(fn), m_before_erasure);
for (unsigned minor_idx = begin_minors; minor_idx < end_minors; minor_idx++) {
@ -459,7 +421,6 @@ class csimp_fn {
jp_val = mk_join_point_lambda(jp_args, jp_val);
}
/* Create new jp */
lean_assert(m_before_erasure);
expr jp_type = cheap_beta_reduce(infer_type(jp_val));
mark_simplified(jp_val);
expr jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_type, jp_val);
@ -487,12 +448,19 @@ class csimp_fn {
lean_trace(name({"compiler", "simp"}),
tout() << "mk_join " << fvar << "\n" << c << "\n---\n"
<< e << "\n======>\n" << mk_app(fn, args) << "\n";);
return some_expr(mk_app(fn, args));
return mk_app(fn, args);
}
}
/* It is not worthwhile to create a join point for the whole `e` since we will not
be able to perform any simplification. */
return none_expr();
/* Create simple join point */
expr jp_val = e;
if (is_lambda(e))
jp_val = mk_trivial_let(jp_val);
jp_val = ::lean::mk_lambda(fvar_decl.get_user_name(), fvar_decl.get_type(), abstract(jp_val, fvar));
expr jp_type = cheap_beta_reduce(infer_type(jp_val));
mark_simplified(jp_val);
expr jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_type, jp_val);
register_new_jp(jp_var);
return mk_app(jp_var, fvar);
}
/* Given `e[x]`, create a let-decl `y := v`, and return `e[y]`
@ -648,14 +616,9 @@ class csimp_fn {
/* Float cases transformation (see: `float_cases_on_core`).
This version may create join points if `e` is big, or "good" join-points could not be created. */
optional<expr> float_cases_on(expr const & x, expr const & e, expr const & c) {
lean_assert(m_before_erasure);
unsigned saved_fvars_size = m_fvars.size();
if (optional<expr> new_e = mk_join_point_float_cases_on(x, e, c)) {
return some_expr(float_cases_on_core(x, *new_e, c));
}
m_fvars.shrink(saved_fvars_size);
return none_expr();
expr float_cases_on(expr const & x, expr const & e, expr const & c) {
expr new_e = mk_join_point_float_cases_on(x, e, c);
return float_cases_on_core(x, new_e, c);
}
/* Given the buffer `entries`: `[(x_1, w_1), ..., (x_n, w_n)]`, and `e`.
@ -888,30 +851,24 @@ class csimp_fn {
}
if (is_cases_on_app(env(), val)) {
/* Float cases transformation. */
if (m_cfg.m_float_cases && m_before_erasure) {
/* We first create a let-declaration with all entries that depends on the current
`x` which is a cases_on application. */
buffer<pair<expr, expr>> entries_dep_curr;
buffer<pair<expr, expr>> entries_ndep_curr;
split_entries(entries, x, entries_dep_curr, entries_ndep_curr);
expr new_e = mk_let_core(entries_dep_curr, e);
if (optional<expr> new_e_opt = float_cases_on(x, new_e, val)) {
e = *new_e_opt;
lean_assert(is_cases_on_app(env(), e));
e_is_cases = true;
/* Reset `e_fvars` and `entries_fvars`, we need to reconstruct them. */
e_fvars.clear(); entries_fvars.clear();
collect_used(e, e_fvars);
/* Join points may have been generated, we move them to entries. */
entries.clear();
/* Copy `entries_ndep_curr` to `entries` */
move_to_entries(entries_ndep_curr, entries, entries_fvars);
continue;
}
}
val = visit_cases_default(val);
modified_val = true;
/* We first create a let-declaration with all entries that depends on the current
`x` which is a cases_on application. */
buffer<pair<expr, expr>> entries_dep_curr;
buffer<pair<expr, expr>> entries_ndep_curr;
// split_entries(entries, x, entries_dep_curr, entries_ndep_curr);
// expr new_e = mk_let_core(entries_dep_curr, e);
expr new_e = mk_let_core(entries, e);
e = float_cases_on(x, new_e, val);
lean_assert(is_cases_on_app(env(), e));
e_is_cases = true;
/* Reset `e_fvars` and `entries_fvars`, we need to reconstruct them. */
e_fvars.clear(); entries_fvars.clear();
collect_used(e, e_fvars);
/* Join points may have been generated, we move them to entries. */
entries.clear();
/* Copy `entries_ndep_curr` to `entries` */
move_to_entries(entries_ndep_curr, entries, entries_fvars);
continue;
}
if (!is_jp && e_is_cases && used_in_e) {
@ -1435,12 +1392,6 @@ class csimp_fn {
expr fn = find(get_app_fn(e));
if (is_lambda(fn)) {
return beta_reduce(fn, e, is_let_val);
} else if (is_cases_on_app(env(), fn) && m_cfg.m_float_cases_app) {
lean_assert(is_fvar(get_app_fn(e)));
/* float cases_on from application */
expr new_e = float_cases_on_core(get_app_fn(e), e, fn);
mark_simplified(new_e);
return new_e;
} else if (is_lc_unreachable_app(fn)) {
lean_assert(m_before_erasure);
expr type = infer_type(e);

View file

@ -13,10 +13,6 @@ struct csimp_cfg {
/* We inline "cheap" functions. We say a function is cheap if `get_lcnf_size(val) < m_inline_threshold`,
and it is not marked as `[noinline]`. */
unsigned m_inline_threshold;
/* Enable float cases_on from application. Remark: this transformation is essential for monadic code. */
bool m_float_cases_app;
/* Enable float cases_on from cases_on and other expressions. */
bool m_float_cases;
/* We only perform float cases_on from cases_on and other expression if the potential code blowup is smaller
than m_float_cases_threshold. */
unsigned m_float_cases_threshold;