diff --git a/src/library/compiler/csimp.cpp b/src/library/compiler/csimp.cpp index 856a142bc0..ac19a9b093 100644 --- a/src/library/compiler/csimp.cpp +++ b/src/library/compiler/csimp.cpp @@ -23,16 +23,16 @@ csimp_cfg::csimp_cfg() { m_inline = true; m_inline_threshold = 4; m_float_cases_app = true; - m_float_cases = false; - m_float_cases_jp_threshold = 3; - m_float_cases_jp_branch_threshold = 3; - m_inline_jp_threshold = 4; + m_float_cases = true; + m_float_cases_jp_threshold = 4; + m_float_cases_jp_branch_threshold = 2; + m_inline_jp_threshold = 2; } class csimp_fn { type_checker::state m_st; local_ctx m_lctx; - csimp_cfg const & m_cfg; + csimp_cfg m_cfg; buffer m_fvars; name m_x; name m_j; @@ -127,16 +127,6 @@ class csimp_fn { return fvar; } - /* Given the `cases_on` application, return [first_minor_idx, first_minor_idx + nminors) */ - pair get_cases_on_minors_range(name const & cases) { - inductive_val I_val = env().get(cases.get_prefix()).to_inductive_val(); - unsigned nparams = I_val.get_nparams(); - unsigned nindices = I_val.get_nindices(); - unsigned nminors = I_val.get_ncnstrs(); - unsigned first_minor_idx = nparams + 1 /*motive*/ + nindices + 1 /* major */; - return mk_pair(first_minor_idx, first_minor_idx + nminors); - } - /* Given a cases_on application `c`, return `some idx` iff `fvar` only occurs in the argument `idx`, this argument is a minor premise. 
*/ optional used_in_one_minor(expr const & c, expr const & fvar) { @@ -145,7 +135,7 @@ class csimp_fn { buffer args; expr const & c_fn = get_app_args(c, args); unsigned minors_begin; unsigned minors_end; - std::tie(minors_begin, minors_end) = get_cases_on_minors_range(const_name(c_fn)); + std::tie(minors_begin, minors_end) = get_cases_on_minors_range(env(), const_name(c_fn)); unsigned i = 0; for (; i < minors_begin; i++) { if (has_fvar(args[i], fvar)) { @@ -237,6 +227,77 @@ class csimp_fn { }); } + /* Return true iff the free variable `x` occurs in a projection or is the major premise of + a `cases_on` application in `e`. */ + bool is_proj_or_cases_on_arg_at(expr const & x, expr const & e) { + lean_assert(is_fvar(x)); + if (!has_fvar(e)) return false; + bool found = false; + for_each(e, [&](expr const & s, unsigned) { + if (!has_fvar(s)) return false; + if (found) return false; + if (is_proj(s) && proj_expr(s) == x) { + found = true; + return false; + } else if (is_cases_on_app(env(), s) && get_cases_on_app_major(env(), s) == x) { + found = true; + return false; + } + return true; + }); + return found; + } + + /* Auxiliary method for `may_return_constructor`. `visited` contains the set of join points that + have already been visited. 
*/ + bool may_return_constructor_core(expr e, name_set & visited) { + while (true) { + if (is_lambda(e)) + e = binding_body(e); + else if (is_let(e)) + e = let_body(e); + else + break; + } + if (is_constructor_app(env(), e)) { + return true; + } else if (is_cases_on_app(env(), e)) { + buffer args; + expr const & fn = get_app_args(e, args); + unsigned begin_minors; unsigned end_minors; + std::tie(begin_minors, end_minors) = get_cases_on_minors_range(env(), const_name(fn)); + for (unsigned i = begin_minors; i < end_minors; i++) { + if (may_return_constructor_core(args[i], visited)) + return true; + } + return false; + } else if (is_join_point_app(e)) { + expr const & fn = get_app_fn(e); + lean_assert(is_fvar(fn)); + if (visited.find(fvar_name(fn)) != visited.end()) + return false; + visited.insert(fvar_name(fn)); + local_decl decl = m_lctx.get_local_decl(fn); + return may_return_constructor_core(*decl.get_value(), visited); + } else { + return false; + } + } + + /* Return true if `e` may return a constructor application. We say "may" because the + result may be a `cases_on`-application and we return true if one of the branches returns a constructor. */ + bool may_return_constructor(expr const & e) { + name_set visited; + return may_return_constructor_core(e, visited); + } + + bool is_float_cases_on_worthwhile(expr const & x, expr const & e, expr const & c) { + lean_assert(is_cases_on_app(env(), c)); + return + is_proj_or_cases_on_arg_at(x, e) && + may_return_constructor(c); + } + /* The `float_cases_on` transformation may produce code duplication. The term `e` is "copied" in each branch of the the `cases_on` expression `c`. This method creates one (or more) join-point(s) for `e` (if needed). 
@@ -244,7 +305,7 @@ class csimp_fn { optional mk_join_point_float_cases_on(expr const & fvar, expr const & e, expr const & c, buffer & new_jps) { lean_assert(is_cases_on_app(env(), c)); expr const & c_fn = get_app_fn(c); - inductive_val I_val = env().get(const_name(c_fn).get_prefix()).to_inductive_val(); + inductive_val I_val = get_cases_on_inductive_val(env(), c_fn); if (I_val.get_ncnstrs() == 1) { /* `c` has only one case. So, only one copy of `e` may be created. */ return some_expr(e); @@ -258,7 +319,7 @@ class csimp_fn { bool modified = false; unsigned saved_fvars_size = m_fvars.size(); unsigned begin_minors; unsigned end_minors; - std::tie(begin_minors, end_minors) = get_cases_on_minors_range(const_name(fn)); + std::tie(begin_minors, end_minors) = get_cases_on_minors_range(env(), const_name(fn)); for (unsigned minor_idx = begin_minors; minor_idx < end_minors; minor_idx++) { expr minor = args[minor_idx]; if (get_lcnf_size(env(), minor) > m_cfg.m_float_cases_jp_branch_threshold) { @@ -467,7 +528,7 @@ class csimp_fn { expr result_type = whnf_infer_type(e); buffer c_args; expr c_fn = get_app_args(c, c_args); - inductive_val I_val = env().get(const_name(c_fn).get_prefix()).to_inductive_val(); + inductive_val I_val = get_cases_on_inductive_val(env(), c_fn); unsigned motive_idx = I_val.get_nparams(); unsigned first_index = motive_idx + 1; unsigned nindices = I_val.get_nindices(); @@ -549,8 +610,7 @@ class csimp_fn { x_1 := w_1 in e ``` - The values `w_i` are the "simplified values" for the let-declaration `x_i`. - */ + The values `w_i` are the "simplified values" for the let-declaration `x_i`. */ expr mk_let(buffer> const & entries, expr e) { buffer fvars; buffer user_names; @@ -604,8 +664,7 @@ class csimp_fn { ... 
x_1 := w_1 in e - ``` - */ + ``` */ void split_entries(buffer> const & entries, expr const & x, buffer> & entries_dep_x, @@ -652,17 +711,17 @@ class csimp_fn { collect_used(e, e_fvars); bool e_is_cases = is_cases_on_app(env(), e); while (m_fvars.size() > saved_fvars_size) { - expr fvar = m_fvars.back(); + expr x = m_fvars.back(); m_fvars.pop_back(); - bool used_in_e = (e_fvars.find(fvar_name(fvar)) != e_fvars.end()); - bool used_in_entries = (entries_fvars.find(fvar_name(fvar)) != entries_fvars.end()); + bool used_in_e = (e_fvars.find(fvar_name(x)) != e_fvars.end()); + bool used_in_entries = (entries_fvars.find(fvar_name(x)) != entries_fvars.end()); if (!used_in_e && !used_in_entries) { /* Skip unused variables */ continue; } - local_decl decl = m_lctx.get_local_decl(fvar); - expr type = decl.get_type(); - expr val = *decl.get_value(); + local_decl x_decl = m_lctx.get_local_decl(x); + expr type = x_decl.get_type(); + expr val = *x_decl.get_value(); bool modified_val = false; if (is_lambda(val)) { /* We don't simplify lambdas when we visit `let`-expressions. */ @@ -672,7 +731,7 @@ class csimp_fn { lean_assert(m_fvars.size() == saved_fvars_size); } - if (is_fvar(e) && entries.empty() && fvar_name(e) == fvar_name(fvar)) { + if (entries.empty() && e == x) { /* `let x := v in x` ==> `v` */ e = val; collect_used(val, e_fvars); @@ -684,36 +743,38 @@ class csimp_fn { /* Float cases transformation. */ if (m_cfg.m_float_cases) { /* We first create a let-declaration with all entries that depends on the current - `fvar` which is a cases_on application. */ + `x` which is a cases_on application. 
*/ buffer> entries_dep_curr; buffer> entries_ndep_curr; - split_entries(entries, fvar, entries_dep_curr, entries_ndep_curr); + split_entries(entries, x, entries_dep_curr, entries_ndep_curr); expr new_e = mk_let(entries_dep_curr, e); - buffer new_jps; - if (optional new_e_opt = float_cases_on(fvar, new_e, val, new_jps)) { - e = *new_e_opt; - /* Reset `e_fvars` and `entries_fvars`, we need to reconstruct them. */ - e_fvars.clear(); entries_fvars.clear(); - collect_used(e, e_fvars); - /* Join points may have been generated, we move them to entries. */ - entries.clear(); - while (!new_jps.empty()) { - expr jp_fvar = new_jps.back(); - new_jps.pop_back(); - local_decl jp_decl = m_lctx.get_local_decl(jp_fvar); - expr jp_type = jp_decl.get_type(); - expr jp_val = *jp_decl.get_value(); - collect_used(jp_type, entries_fvars); - collect_used(jp_val, entries_fvars); - entries.emplace_back(jp_fvar, jp_val); + if (is_float_cases_on_worthwhile(x, new_e, val)) { + buffer new_jps; + if (optional new_e_opt = float_cases_on(x, new_e, val, new_jps)) { + e = *new_e_opt; + /* Reset `e_fvars` and `entries_fvars`, we need to reconstruct them. */ + e_fvars.clear(); entries_fvars.clear(); + collect_used(e, e_fvars); + /* Join points may have been generated, we move them to entries. 
*/ + entries.clear(); + while (!new_jps.empty()) { + expr jp_fvar = new_jps.back(); + new_jps.pop_back(); + local_decl jp_decl = m_lctx.get_local_decl(jp_fvar); + expr jp_type = jp_decl.get_type(); + expr jp_val = *jp_decl.get_value(); + collect_used(jp_type, entries_fvars); + collect_used(jp_val, entries_fvars); + entries.emplace_back(jp_fvar, jp_val); + } + /* Copy `entries_ndep_curr` to `entries` */ + for (unsigned i = 0; i < entries_ndep_curr.size(); i++) { + pair const & ndep_entry = entries_ndep_curr[i]; + entries.push_back(ndep_entry); + collect_used(ndep_entry.second, entries_fvars); + } + continue; } - /* Copy `entries_ndep_curr` to `entries` */ - for (unsigned i = 0; i < entries_ndep_curr.size(); i++) { - pair const & ndep_entry = entries_ndep_curr[i]; - entries.push_back(ndep_entry); - collect_used(ndep_entry.second, entries_fvars); - } - continue; } } val = visit_cases_default(val); @@ -721,26 +782,26 @@ class csimp_fn { } if (e_is_cases && used_in_e) { - optional minor_idx = used_in_one_minor(e, fvar); + optional minor_idx = used_in_one_minor(e, x); if (minor_idx && !used_in_entries) { - /* If fvar is only used in only one minor declaration, + /* If x is used in only one minor declaration, and is *not* used in any expression at entries */ if (modified_val) { /* We need to create a new free variable since the new simplified value `val` */ - expr new_fvar = m_lctx.mk_local_decl(ngen(), decl.get_user_name(), type, val); - e = replace_fvar(e, fvar, new_fvar); - fvar = new_fvar; + expr new_x = m_lctx.mk_local_decl(ngen(), x_decl.get_user_name(), type, val); + e = replace_fvar(e, x, new_x); + x = new_x; } collect_used(type, e_fvars); collect_used(val, e_fvars); - e = move_let_to_minor(e, *minor_idx, fvar); + e = move_let_to_minor(e, *minor_idx, x); continue; } } collect_used(type, entries_fvars); collect_used(val, entries_fvars); - entries.emplace_back(fvar, val); + entries.emplace_back(x, val); } return mk_let(entries, e); } @@ -892,7 +953,7 @@ class 
csimp_fn { expr const & c = get_app_args(e, args); /* simplify minor premises */ unsigned minor_idx; unsigned minors_end; - std::tie(minor_idx, minors_end) = get_cases_on_minors_range(const_name(c)); + std::tie(minor_idx, minors_end) = get_cases_on_minors_range(env(), const_name(c)); for (; minor_idx < minors_end; minor_idx++) { expr minor = args[minor_idx]; unsigned saved_fvars_size = m_fvars.size(); @@ -933,7 +994,7 @@ class csimp_fn { buffer args; expr const & c = get_app_args(e, args); lean_assert(is_constant(c)); - inductive_val I_val = env().get(const_name(c).get_prefix()).to_inductive_val(); + inductive_val I_val = get_cases_on_inductive_val(env(), c); unsigned major_idx = I_val.get_nparams() + 1 /* typeformer/motive */ + I_val.get_nindices(); lean_assert(major_idx < args.size()); expr const & major = find(args[major_idx]); @@ -994,8 +1055,7 @@ class csimp_fn { b_1 := lc_cast a_1 ... b_n := lc_cast a_n - e := g b_1 ... b_n - */ + e := g b_1 ... b_n */ expr const & g = app_arg(fn); buffer args; get_app_args(e, args); diff --git a/src/library/compiler/util.cpp b/src/library/compiler/util.cpp index 478c91b6c7..a51699bf76 100644 --- a/src/library/compiler/util.cpp +++ b/src/library/compiler/util.cpp @@ -125,6 +125,23 @@ bool is_cases_on_recursor(environment const & env, name const & n) { return ::lean::is_aux_recursor(env, n) && n.get_string() == "cases_on"; } +expr get_cases_on_app_major(environment const & env, expr const & c) { + lean_assert(is_cases_on_app(env, c)); + buffer args; + expr const & fn = get_app_args(c, args); + inductive_val I_val = get_cases_on_inductive_val(env, fn); + return args[I_val.get_nparams() + 1 /* motive */ + I_val.get_nindices()]; +} + +pair get_cases_on_minors_range(environment const & env, name const & c) { + inductive_val I_val = get_cases_on_inductive_val(env, c); + unsigned nparams = I_val.get_nparams(); + unsigned nindices = I_val.get_nindices(); + unsigned nminors = I_val.get_ncnstrs(); + unsigned first_minor_idx = nparams + 
1 /*motive*/ + nindices + 1 /* major */; + return mk_pair(first_minor_idx, first_minor_idx + nminors); +} + expr mk_lc_unreachable(type_checker::state & s, local_ctx const & lctx, expr const & type) { type_checker tc(s, lctx); level lvl = sort_level(tc.ensure_type(type)); diff --git a/src/library/compiler/util.h b/src/library/compiler/util.h index 205ab8e7a8..649687e825 100644 --- a/src/library/compiler/util.h +++ b/src/library/compiler/util.h @@ -30,10 +30,26 @@ expr unfold_macro_defs(environment const & env, expr const & e); inline bool is_lc_mdata(expr const &) { return false; } bool is_cases_on_recursor(environment const & env, name const & n); +/* Return the `inductive_val` for the cases_on constant `c`. */ +inline inductive_val get_cases_on_inductive_val(environment const & env, name const & c) { + lean_assert(is_cases_on_recursor(env, c)); + return env.get(c.get_prefix()).to_inductive_val(); +} +inline inductive_val get_cases_on_inductive_val(environment const & env, expr const & c) { + lean_assert(is_constant(c)); + return get_cases_on_inductive_val(env, const_name(c)); +} inline bool is_cases_on_app(environment const & env, expr const & e) { expr const & fn = get_app_fn(e); return is_constant(fn) && is_cases_on_recursor(env, const_name(fn)); } +/* Return the major premise of a cases_on-application. + \pre is_cases_on_app(env, c) */ +expr get_cases_on_app_major(environment const & env, expr const & c); +/* Return the pair `(b, e)` such that argument `i` of a `c` cases_on + application is a minor premise iff `i` is in `[b, e)`. + \pre is_cases_on_recursor(env, c) */ +pair get_cases_on_minors_range(environment const & env, name const & c); inline bool is_lc_unreachable_app(expr const & e) { return is_app_of(e, get_lc_unreachable_name(), 1); } inline bool is_lc_proof_app(expr const & e) { return is_app_of(e, get_lc_proof_name(), 1); }