diff --git a/src/library/compiler/csimp.cpp b/src/library/compiler/csimp.cpp index 19a1fb7e7e..59bd68c6d5 100644 --- a/src/library/compiler/csimp.cpp +++ b/src/library/compiler/csimp.cpp @@ -28,8 +28,6 @@ csimp_cfg::csimp_cfg(options const &): csimp_cfg::csimp_cfg() { m_inline = true; m_inline_threshold = 1; - m_float_cases_app = true; - m_float_cases = true; m_float_cases_threshold = 20; m_inline_jp_threshold = 2; } @@ -309,30 +307,6 @@ class csimp_fn { }); } - /* Return true iff the free variable `x` occurs in a projection or is the major premise of - a `cases_on` application in `e`. */ - bool is_proj_or_cases_on_arg_at(expr const & x, expr const & e) { - lean_assert(m_before_erasure); - lean_assert(is_fvar(x)); - if (!has_fvar(e)) return false; - bool found = false; - for_each(e, [&](expr const & s, unsigned) { - if (!has_fvar(s)) return false; - if (found) return false; - if (is_proj(s) && proj_expr(s) == x) { - found = true; - return false; - } else if (is_cases_on_app(env(), s) && - get_app_num_args(s) == get_cases_on_arity(env(), const_name(get_app_fn(s)), m_before_erasure) && - get_cases_on_app_major(env(), s, m_before_erasure) == x) { - found = true; - return false; - } - return true; - }); - return found; - } - /* Collect information for deciding whether `float_cases_on` is useful or not, and control code blowup. */ struct cases_info_result { @@ -383,31 +357,20 @@ class csimp_fn { This method creates one (or more) join-point(s) for `e` (if needed). Return `none` if the code size increase is above the threshold. Remark: it may produce type incorrect terms. */ - optional mk_join_point_float_cases_on(expr const & fvar, expr const & e, expr const & c) { - lean_assert(m_before_erasure); + expr mk_join_point_float_cases_on(expr const & fvar, expr const & e, expr const & c) { lean_assert(is_cases_on_app(env(), c)); unsigned e_size = get_lcnf_size(env(), e); if (e_size == 1) { - return some_expr(e); - } - if (!is_proj_or_cases_on_arg_at(fvar, e)) { - /* It is not worthwhile to apply `float_cases_on` since `e` does not project or destruct the result produced - by `c`. */ - return none_expr(); + return e; } cases_info_result c_info; collect_cases_info(c, c_info); - if (c_info.m_num_cnstr_results == 0) { - /* It is not worthwhile to apply `float_cases_on` since none of `c` branches return a constructor. */ - return none_expr(); - } - lean_assert(c_info.m_num_branches > 0); - lean_assert(c_info.m_num_cnstr_results <= c_info.m_num_branches); unsigned code_increase = e_size*(c_info.m_num_branches - 1); if (code_increase <= m_cfg.m_float_cases_threshold) { - return some(e); - } else if (is_cases_on_app(env(), e)) { - local_decl fvar_decl = m_lctx.get_local_decl(fvar); + return e; + } + local_decl fvar_decl = m_lctx.get_local_decl(fvar); + if (is_cases_on_app(env(), e)) { buffer args; expr const & fn = get_app_args(e, args); inductive_val e_I_val = get_cases_on_inductive_val(env(), fn); @@ -420,7 +383,6 @@ class csimp_fn { unsigned new_code_increase = e_compressed_size*(c_info.m_num_branches - c_info.m_num_cnstr_results); if (new_code_increase <= m_cfg.m_float_cases_threshold) { unsigned branch_threshold = m_cfg.m_float_cases_threshold / (c_info.m_num_branches - 1); - lean_assert(m_before_erasure); unsigned begin_minors; unsigned end_minors; std::tie(begin_minors, end_minors) = get_cases_on_minors_range(env(), const_name(fn), m_before_erasure); for (unsigned minor_idx = begin_minors; minor_idx < end_minors; minor_idx++) { @@ -459,7 +421,6 @@ class csimp_fn { jp_val = mk_join_point_lambda(jp_args, jp_val); } /* Create new jp */ - lean_assert(m_before_erasure); expr jp_type = cheap_beta_reduce(infer_type(jp_val)); mark_simplified(jp_val); expr jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_type, jp_val); @@ -487,12 +448,19 @@ class csimp_fn { lean_trace(name({"compiler", "simp"}), tout() << "mk_join " << fvar << "\n" << c << "\n---\n" << e << "\n======>\n" << mk_app(fn, args) << "\n";); - return some_expr(mk_app(fn, args)); + return mk_app(fn, args); } } - /* It is not worthwhile to create a join point for the whole `e` since we will not - be able to perform any simplification. */ - return none_expr(); + /* Create simple join point */ + expr jp_val = e; + if (is_lambda(e)) + jp_val = mk_trivial_let(jp_val); + jp_val = ::lean::mk_lambda(fvar_decl.get_user_name(), fvar_decl.get_type(), abstract(jp_val, fvar)); + expr jp_type = cheap_beta_reduce(infer_type(jp_val)); + mark_simplified(jp_val); + expr jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_type, jp_val); + register_new_jp(jp_var); + return mk_app(jp_var, fvar); } /* Given `e[x]`, create a let-decl `y := v`, and return `e[y]` @@ -648,14 +616,9 @@ class csimp_fn { /* Float cases transformation (see: `float_cases_on_core`). This version may create join points if `e` is big, or "good" join-points could not be created. */ - optional float_cases_on(expr const & x, expr const & e, expr const & c) { - lean_assert(m_before_erasure); - unsigned saved_fvars_size = m_fvars.size(); - if (optional new_e = mk_join_point_float_cases_on(x, e, c)) { - return some_expr(float_cases_on_core(x, *new_e, c)); - } - m_fvars.shrink(saved_fvars_size); - return none_expr(); + expr float_cases_on(expr const & x, expr const & e, expr const & c) { + expr new_e = mk_join_point_float_cases_on(x, e, c); + return float_cases_on_core(x, new_e, c); } /* Given the buffer `entries`: `[(x_1, w_1), ..., (x_n, w_n)]`, and `e`. @@ -888,30 +851,24 @@ class csimp_fn { } if (is_cases_on_app(env(), val)) { - /* Float cases transformation. */ - if (m_cfg.m_float_cases && m_before_erasure) { - /* We first create a let-declaration with all entries that depends on the current - `x` which is a cases_on application. */ - buffer> entries_dep_curr; - buffer> entries_ndep_curr; - split_entries(entries, x, entries_dep_curr, entries_ndep_curr); - expr new_e = mk_let_core(entries_dep_curr, e); - if (optional new_e_opt = float_cases_on(x, new_e, val)) { - e = *new_e_opt; - lean_assert(is_cases_on_app(env(), e)); - e_is_cases = true; - /* Reset `e_fvars` and `entries_fvars`, we need to reconstruct them. */ - e_fvars.clear(); entries_fvars.clear(); - collect_used(e, e_fvars); - /* Join points may have been generated, we move them to entries. */ - entries.clear(); - /* Copy `entries_ndep_curr` to `entries` */ - move_to_entries(entries_ndep_curr, entries, entries_fvars); - continue; - } - } - val = visit_cases_default(val); - modified_val = true; + /* We first create a let-declaration with all entries that depends on the current + `x` which is a cases_on application. */ + buffer> entries_dep_curr; + buffer> entries_ndep_curr; + // split_entries(entries, x, entries_dep_curr, entries_ndep_curr); + // expr new_e = mk_let_core(entries_dep_curr, e); + expr new_e = mk_let_core(entries, e); + e = float_cases_on(x, new_e, val); + lean_assert(is_cases_on_app(env(), e)); + e_is_cases = true; + /* Reset `e_fvars` and `entries_fvars`, we need to reconstruct them. */ + e_fvars.clear(); entries_fvars.clear(); + collect_used(e, e_fvars); + /* Join points may have been generated, we move them to entries. */ + entries.clear(); + /* Copy `entries_ndep_curr` to `entries` */ + move_to_entries(entries_ndep_curr, entries, entries_fvars); + continue; } if (!is_jp && e_is_cases && used_in_e) { @@ -1435,12 +1392,6 @@ class csimp_fn { expr fn = find(get_app_fn(e)); if (is_lambda(fn)) { return beta_reduce(fn, e, is_let_val); - } else if (is_cases_on_app(env(), fn) && m_cfg.m_float_cases_app) { - lean_assert(is_fvar(get_app_fn(e))); - /* float cases_on from application */ - expr new_e = float_cases_on_core(get_app_fn(e), e, fn); - mark_simplified(new_e); - return new_e; } else if (is_lc_unreachable_app(fn)) { lean_assert(m_before_erasure); expr type = infer_type(e); diff --git a/src/library/compiler/csimp.h b/src/library/compiler/csimp.h index 79202ea659..db3746a770 100644 --- a/src/library/compiler/csimp.h +++ b/src/library/compiler/csimp.h @@ -13,10 +13,6 @@ struct csimp_cfg { /* We inline "cheap" functions. We say a function is cheap if `get_lcnf_size(val) < m_inline_threshold`, and it is not marked as `[noinline]`. */ unsigned m_inline_threshold; - /* Enable float cases_on from application. Remark: this transformation is essential for monadic code. */ - bool m_float_cases_app; - /* Enable float cases_on from cases_on and other expressions. */ - bool m_float_cases; /* We only perform float cases_on from cases_on and other expression if the potential code blowup is smaller than m_float_cases_threshold. */ unsigned m_float_cases_threshold;