feat(library/compiler/csimp): float all cases
Motivation: explicit control flow graph TODO: disabled `split_entries` for now. I believe the new feature exposed a bug at `move_to_entries`. I will fix this new issue in another commit.
This commit is contained in:
parent
4ba3cc390f
commit
7ec00c97e9
2 changed files with 38 additions and 91 deletions
|
|
@ -28,8 +28,6 @@ csimp_cfg::csimp_cfg(options const &):
|
|||
csimp_cfg::csimp_cfg() {
|
||||
m_inline = true;
|
||||
m_inline_threshold = 1;
|
||||
m_float_cases_app = true;
|
||||
m_float_cases = true;
|
||||
m_float_cases_threshold = 20;
|
||||
m_inline_jp_threshold = 2;
|
||||
}
|
||||
|
|
@ -309,30 +307,6 @@ class csimp_fn {
|
|||
});
|
||||
}
|
||||
|
||||
/* Return true iff the free variable `x` occurs in a projection or is the major premise of
|
||||
a `cases_on` application in `e`. */
|
||||
bool is_proj_or_cases_on_arg_at(expr const & x, expr const & e) {
|
||||
lean_assert(m_before_erasure);
|
||||
lean_assert(is_fvar(x));
|
||||
if (!has_fvar(e)) return false;
|
||||
bool found = false;
|
||||
for_each(e, [&](expr const & s, unsigned) {
|
||||
if (!has_fvar(s)) return false;
|
||||
if (found) return false;
|
||||
if (is_proj(s) && proj_expr(s) == x) {
|
||||
found = true;
|
||||
return false;
|
||||
} else if (is_cases_on_app(env(), s) &&
|
||||
get_app_num_args(s) == get_cases_on_arity(env(), const_name(get_app_fn(s)), m_before_erasure) &&
|
||||
get_cases_on_app_major(env(), s, m_before_erasure) == x) {
|
||||
found = true;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
|
||||
/* Collect information for deciding whether `float_cases_on` is useful or not, and control
|
||||
code blowup. */
|
||||
struct cases_info_result {
|
||||
|
|
@ -383,31 +357,20 @@ class csimp_fn {
|
|||
This method creates one (or more) join-point(s) for `e` (if needed).
|
||||
Return `none` if the code size increase is above the threshold.
|
||||
Remark: it may produce type incorrect terms. */
|
||||
optional<expr> mk_join_point_float_cases_on(expr const & fvar, expr const & e, expr const & c) {
|
||||
lean_assert(m_before_erasure);
|
||||
expr mk_join_point_float_cases_on(expr const & fvar, expr const & e, expr const & c) {
|
||||
lean_assert(is_cases_on_app(env(), c));
|
||||
unsigned e_size = get_lcnf_size(env(), e);
|
||||
if (e_size == 1) {
|
||||
return some_expr(e);
|
||||
}
|
||||
if (!is_proj_or_cases_on_arg_at(fvar, e)) {
|
||||
/* It is not worthwhile to apply `float_cases_on` since `e` does not project or destruct the result produced
|
||||
by `c`. */
|
||||
return none_expr();
|
||||
return e;
|
||||
}
|
||||
cases_info_result c_info;
|
||||
collect_cases_info(c, c_info);
|
||||
if (c_info.m_num_cnstr_results == 0) {
|
||||
/* It is not worthwhile to apply `float_cases_on` since none of `c` branches return a constructor. */
|
||||
return none_expr();
|
||||
}
|
||||
lean_assert(c_info.m_num_branches > 0);
|
||||
lean_assert(c_info.m_num_cnstr_results <= c_info.m_num_branches);
|
||||
unsigned code_increase = e_size*(c_info.m_num_branches - 1);
|
||||
if (code_increase <= m_cfg.m_float_cases_threshold) {
|
||||
return some(e);
|
||||
} else if (is_cases_on_app(env(), e)) {
|
||||
local_decl fvar_decl = m_lctx.get_local_decl(fvar);
|
||||
return e;
|
||||
}
|
||||
local_decl fvar_decl = m_lctx.get_local_decl(fvar);
|
||||
if (is_cases_on_app(env(), e)) {
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(e, args);
|
||||
inductive_val e_I_val = get_cases_on_inductive_val(env(), fn);
|
||||
|
|
@ -420,7 +383,6 @@ class csimp_fn {
|
|||
unsigned new_code_increase = e_compressed_size*(c_info.m_num_branches - c_info.m_num_cnstr_results);
|
||||
if (new_code_increase <= m_cfg.m_float_cases_threshold) {
|
||||
unsigned branch_threshold = m_cfg.m_float_cases_threshold / (c_info.m_num_branches - 1);
|
||||
lean_assert(m_before_erasure);
|
||||
unsigned begin_minors; unsigned end_minors;
|
||||
std::tie(begin_minors, end_minors) = get_cases_on_minors_range(env(), const_name(fn), m_before_erasure);
|
||||
for (unsigned minor_idx = begin_minors; minor_idx < end_minors; minor_idx++) {
|
||||
|
|
@ -459,7 +421,6 @@ class csimp_fn {
|
|||
jp_val = mk_join_point_lambda(jp_args, jp_val);
|
||||
}
|
||||
/* Create new jp */
|
||||
lean_assert(m_before_erasure);
|
||||
expr jp_type = cheap_beta_reduce(infer_type(jp_val));
|
||||
mark_simplified(jp_val);
|
||||
expr jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_type, jp_val);
|
||||
|
|
@ -487,12 +448,19 @@ class csimp_fn {
|
|||
lean_trace(name({"compiler", "simp"}),
|
||||
tout() << "mk_join " << fvar << "\n" << c << "\n---\n"
|
||||
<< e << "\n======>\n" << mk_app(fn, args) << "\n";);
|
||||
return some_expr(mk_app(fn, args));
|
||||
return mk_app(fn, args);
|
||||
}
|
||||
}
|
||||
/* It is not worthwhile to create a join point for the whole `e` since we will not
|
||||
be able to perform any simplification. */
|
||||
return none_expr();
|
||||
/* Create simple join point */
|
||||
expr jp_val = e;
|
||||
if (is_lambda(e))
|
||||
jp_val = mk_trivial_let(jp_val);
|
||||
jp_val = ::lean::mk_lambda(fvar_decl.get_user_name(), fvar_decl.get_type(), abstract(jp_val, fvar));
|
||||
expr jp_type = cheap_beta_reduce(infer_type(jp_val));
|
||||
mark_simplified(jp_val);
|
||||
expr jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_type, jp_val);
|
||||
register_new_jp(jp_var);
|
||||
return mk_app(jp_var, fvar);
|
||||
}
|
||||
|
||||
/* Given `e[x]`, create a let-decl `y := v`, and return `e[y]`
|
||||
|
|
@ -648,14 +616,9 @@ class csimp_fn {
|
|||
|
||||
/* Float cases transformation (see: `float_cases_on_core`).
|
||||
This version may create join points if `e` is big, or "good" join-points could not be created. */
|
||||
optional<expr> float_cases_on(expr const & x, expr const & e, expr const & c) {
|
||||
lean_assert(m_before_erasure);
|
||||
unsigned saved_fvars_size = m_fvars.size();
|
||||
if (optional<expr> new_e = mk_join_point_float_cases_on(x, e, c)) {
|
||||
return some_expr(float_cases_on_core(x, *new_e, c));
|
||||
}
|
||||
m_fvars.shrink(saved_fvars_size);
|
||||
return none_expr();
|
||||
expr float_cases_on(expr const & x, expr const & e, expr const & c) {
|
||||
expr new_e = mk_join_point_float_cases_on(x, e, c);
|
||||
return float_cases_on_core(x, new_e, c);
|
||||
}
|
||||
|
||||
/* Given the buffer `entries`: `[(x_1, w_1), ..., (x_n, w_n)]`, and `e`.
|
||||
|
|
@ -888,30 +851,24 @@ class csimp_fn {
|
|||
}
|
||||
|
||||
if (is_cases_on_app(env(), val)) {
|
||||
/* Float cases transformation. */
|
||||
if (m_cfg.m_float_cases && m_before_erasure) {
|
||||
/* We first create a let-declaration with all entries that depends on the current
|
||||
`x` which is a cases_on application. */
|
||||
buffer<pair<expr, expr>> entries_dep_curr;
|
||||
buffer<pair<expr, expr>> entries_ndep_curr;
|
||||
split_entries(entries, x, entries_dep_curr, entries_ndep_curr);
|
||||
expr new_e = mk_let_core(entries_dep_curr, e);
|
||||
if (optional<expr> new_e_opt = float_cases_on(x, new_e, val)) {
|
||||
e = *new_e_opt;
|
||||
lean_assert(is_cases_on_app(env(), e));
|
||||
e_is_cases = true;
|
||||
/* Reset `e_fvars` and `entries_fvars`, we need to reconstruct them. */
|
||||
e_fvars.clear(); entries_fvars.clear();
|
||||
collect_used(e, e_fvars);
|
||||
/* Join points may have been generated, we move them to entries. */
|
||||
entries.clear();
|
||||
/* Copy `entries_ndep_curr` to `entries` */
|
||||
move_to_entries(entries_ndep_curr, entries, entries_fvars);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
val = visit_cases_default(val);
|
||||
modified_val = true;
|
||||
/* We first create a let-declaration with all entries that depends on the current
|
||||
`x` which is a cases_on application. */
|
||||
buffer<pair<expr, expr>> entries_dep_curr;
|
||||
buffer<pair<expr, expr>> entries_ndep_curr;
|
||||
// split_entries(entries, x, entries_dep_curr, entries_ndep_curr);
|
||||
// expr new_e = mk_let_core(entries_dep_curr, e);
|
||||
expr new_e = mk_let_core(entries, e);
|
||||
e = float_cases_on(x, new_e, val);
|
||||
lean_assert(is_cases_on_app(env(), e));
|
||||
e_is_cases = true;
|
||||
/* Reset `e_fvars` and `entries_fvars`, we need to reconstruct them. */
|
||||
e_fvars.clear(); entries_fvars.clear();
|
||||
collect_used(e, e_fvars);
|
||||
/* Join points may have been generated, we move them to entries. */
|
||||
entries.clear();
|
||||
/* Copy `entries_ndep_curr` to `entries` */
|
||||
move_to_entries(entries_ndep_curr, entries, entries_fvars);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!is_jp && e_is_cases && used_in_e) {
|
||||
|
|
@ -1435,12 +1392,6 @@ class csimp_fn {
|
|||
expr fn = find(get_app_fn(e));
|
||||
if (is_lambda(fn)) {
|
||||
return beta_reduce(fn, e, is_let_val);
|
||||
} else if (is_cases_on_app(env(), fn) && m_cfg.m_float_cases_app) {
|
||||
lean_assert(is_fvar(get_app_fn(e)));
|
||||
/* float cases_on from application */
|
||||
expr new_e = float_cases_on_core(get_app_fn(e), e, fn);
|
||||
mark_simplified(new_e);
|
||||
return new_e;
|
||||
} else if (is_lc_unreachable_app(fn)) {
|
||||
lean_assert(m_before_erasure);
|
||||
expr type = infer_type(e);
|
||||
|
|
|
|||
|
|
@ -13,10 +13,6 @@ struct csimp_cfg {
|
|||
/* We inline "cheap" functions. We say a function is cheap if `get_lcnf_size(val) < m_inline_threshold`,
|
||||
and it is not marked as `[noinline]`. */
|
||||
unsigned m_inline_threshold;
|
||||
/* Enable float cases_on from application. Remark: this transformation is essential for monadic code. */
|
||||
bool m_float_cases_app;
|
||||
/* Enable float cases_on from cases_on and other expressions. */
|
||||
bool m_float_cases;
|
||||
/* We only perform float cases_on from cases_on and other expression if the potential code blowup is smaller
|
||||
than m_float_cases_threshold. */
|
||||
unsigned m_float_cases_threshold;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue