feat(library/compiler): add is_float_cases_on_worthwhile predicate and cleanup
This commit is contained in:
parent
78de3de764
commit
d880b1c640
3 changed files with 160 additions and 67 deletions
|
|
@ -23,16 +23,16 @@ csimp_cfg::csimp_cfg() {
|
|||
m_inline = true;
|
||||
m_inline_threshold = 4;
|
||||
m_float_cases_app = true;
|
||||
m_float_cases = false;
|
||||
m_float_cases_jp_threshold = 3;
|
||||
m_float_cases_jp_branch_threshold = 3;
|
||||
m_inline_jp_threshold = 4;
|
||||
m_float_cases = true;
|
||||
m_float_cases_jp_threshold = 4;
|
||||
m_float_cases_jp_branch_threshold = 2;
|
||||
m_inline_jp_threshold = 2;
|
||||
}
|
||||
|
||||
class csimp_fn {
|
||||
type_checker::state m_st;
|
||||
local_ctx m_lctx;
|
||||
csimp_cfg const & m_cfg;
|
||||
csimp_cfg m_cfg;
|
||||
buffer<expr> m_fvars;
|
||||
name m_x;
|
||||
name m_j;
|
||||
|
|
@ -127,16 +127,6 @@ class csimp_fn {
|
|||
return fvar;
|
||||
}
|
||||
|
||||
/* Given the `cases_on` application, return [first_minor_idx, first_minor_idx + nminors) */
|
||||
pair<unsigned, unsigned> get_cases_on_minors_range(name const & cases) {
|
||||
inductive_val I_val = env().get(cases.get_prefix()).to_inductive_val();
|
||||
unsigned nparams = I_val.get_nparams();
|
||||
unsigned nindices = I_val.get_nindices();
|
||||
unsigned nminors = I_val.get_ncnstrs();
|
||||
unsigned first_minor_idx = nparams + 1 /*motive*/ + nindices + 1 /* major */;
|
||||
return mk_pair(first_minor_idx, first_minor_idx + nminors);
|
||||
}
|
||||
|
||||
/* Given a cases_on application `c`, return `some idx` iff `fvar` only occurs
|
||||
in the argument `idx`, this argument is a minor premise. */
|
||||
optional<unsigned> used_in_one_minor(expr const & c, expr const & fvar) {
|
||||
|
|
@ -145,7 +135,7 @@ class csimp_fn {
|
|||
buffer<expr> args;
|
||||
expr const & c_fn = get_app_args(c, args);
|
||||
unsigned minors_begin; unsigned minors_end;
|
||||
std::tie(minors_begin, minors_end) = get_cases_on_minors_range(const_name(c_fn));
|
||||
std::tie(minors_begin, minors_end) = get_cases_on_minors_range(env(), const_name(c_fn));
|
||||
unsigned i = 0;
|
||||
for (; i < minors_begin; i++) {
|
||||
if (has_fvar(args[i], fvar)) {
|
||||
|
|
@ -237,6 +227,77 @@ class csimp_fn {
|
|||
});
|
||||
}
|
||||
|
||||
/* Return true iff the free variable `x` occurs in a projection or is the major premise of
|
||||
a `cases_on` application in `e`. */
|
||||
bool is_proj_or_cases_on_arg_at(expr const & x, expr const & e) {
|
||||
lean_assert(is_fvar(x));
|
||||
if (!has_fvar(e)) return false;
|
||||
bool found = false;
|
||||
for_each(e, [&](expr const & s, unsigned) {
|
||||
if (!has_fvar(s)) return false;
|
||||
if (found) return false;
|
||||
if (is_proj(s) && proj_expr(s) == x) {
|
||||
found = true;
|
||||
return false;
|
||||
} else if (is_cases_on_app(env(), s) && get_cases_on_app_major(env(), s) == x) {
|
||||
found = true;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
|
||||
/* Auxiliary method for `may_return_constructor`. `visited` contains the set of join points that
|
||||
have already been visited. */
|
||||
bool may_return_constructor_core(expr e, name_set & visited) {
|
||||
while (true) {
|
||||
if (is_lambda(e))
|
||||
e = binding_body(e);
|
||||
else if (is_let(e))
|
||||
e = let_body(e);
|
||||
else
|
||||
break;
|
||||
}
|
||||
if (is_constructor_app(env(), e)) {
|
||||
return true;
|
||||
} else if (is_cases_on_app(env(), e)) {
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(e, args);
|
||||
unsigned begin_minors; unsigned end_minors;
|
||||
std::tie(begin_minors, end_minors) = get_cases_on_minors_range(env(), const_name(fn));
|
||||
for (unsigned i = begin_minors; i < end_minors; i++) {
|
||||
if (may_return_constructor_core(args[i], visited))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
} else if (is_join_point_app(e)) {
|
||||
expr const & fn = get_app_fn(e);
|
||||
lean_assert(is_fvar(fn));
|
||||
if (visited.find(fvar_name(fn)) != visited.end())
|
||||
return false;
|
||||
visited.insert(fvar_name(fn));
|
||||
local_decl decl = m_lctx.get_local_decl(fn);
|
||||
return may_return_constructor_core(*decl.get_value(), visited);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/* Return true if `e` may return a constructor application. We say "may" because the
|
||||
result may be a `cases_on`-application and we return true if one of the branches return a constructor. */
|
||||
bool may_return_constructor(expr const & e) {
|
||||
name_set visited;
|
||||
return may_return_constructor_core(e, visited);
|
||||
}
|
||||
|
||||
bool is_float_cases_on_worthwhile(expr const & x, expr const & e, expr const & c) {
|
||||
lean_assert(is_cases_on_app(env(), c));
|
||||
return
|
||||
is_proj_or_cases_on_arg_at(x, e) &&
|
||||
may_return_constructor(c);
|
||||
}
|
||||
|
||||
/* The `float_cases_on` transformation may produce code duplication.
|
||||
The term `e` is "copied" in each branch of the the `cases_on` expression `c`.
|
||||
This method creates one (or more) join-point(s) for `e` (if needed).
|
||||
|
|
@ -244,7 +305,7 @@ class csimp_fn {
|
|||
optional<expr> mk_join_point_float_cases_on(expr const & fvar, expr const & e, expr const & c, buffer<expr> & new_jps) {
|
||||
lean_assert(is_cases_on_app(env(), c));
|
||||
expr const & c_fn = get_app_fn(c);
|
||||
inductive_val I_val = env().get(const_name(c_fn).get_prefix()).to_inductive_val();
|
||||
inductive_val I_val = get_cases_on_inductive_val(env(), c_fn);
|
||||
if (I_val.get_ncnstrs() == 1) {
|
||||
/* `c` has only one case. So, only one copy of `e` may be created. */
|
||||
return some_expr(e);
|
||||
|
|
@ -258,7 +319,7 @@ class csimp_fn {
|
|||
bool modified = false;
|
||||
unsigned saved_fvars_size = m_fvars.size();
|
||||
unsigned begin_minors; unsigned end_minors;
|
||||
std::tie(begin_minors, end_minors) = get_cases_on_minors_range(const_name(fn));
|
||||
std::tie(begin_minors, end_minors) = get_cases_on_minors_range(env(), const_name(fn));
|
||||
for (unsigned minor_idx = begin_minors; minor_idx < end_minors; minor_idx++) {
|
||||
expr minor = args[minor_idx];
|
||||
if (get_lcnf_size(env(), minor) > m_cfg.m_float_cases_jp_branch_threshold) {
|
||||
|
|
@ -467,7 +528,7 @@ class csimp_fn {
|
|||
expr result_type = whnf_infer_type(e);
|
||||
buffer<expr> c_args;
|
||||
expr c_fn = get_app_args(c, c_args);
|
||||
inductive_val I_val = env().get(const_name(c_fn).get_prefix()).to_inductive_val();
|
||||
inductive_val I_val = get_cases_on_inductive_val(env(), c_fn);
|
||||
unsigned motive_idx = I_val.get_nparams();
|
||||
unsigned first_index = motive_idx + 1;
|
||||
unsigned nindices = I_val.get_nindices();
|
||||
|
|
@ -549,8 +610,7 @@ class csimp_fn {
|
|||
x_1 := w_1
|
||||
in e
|
||||
```
|
||||
The values `w_i` are the "simplified values" for the let-declaration `x_i`.
|
||||
*/
|
||||
The values `w_i` are the "simplified values" for the let-declaration `x_i`. */
|
||||
expr mk_let(buffer<pair<expr, expr>> const & entries, expr e) {
|
||||
buffer<expr> fvars;
|
||||
buffer<name> user_names;
|
||||
|
|
@ -604,8 +664,7 @@ class csimp_fn {
|
|||
...
|
||||
x_1 := w_1
|
||||
in e
|
||||
```
|
||||
*/
|
||||
``` */
|
||||
void split_entries(buffer<pair<expr, expr>> const & entries,
|
||||
expr const & x,
|
||||
buffer<pair<expr, expr>> & entries_dep_x,
|
||||
|
|
@ -652,17 +711,17 @@ class csimp_fn {
|
|||
collect_used(e, e_fvars);
|
||||
bool e_is_cases = is_cases_on_app(env(), e);
|
||||
while (m_fvars.size() > saved_fvars_size) {
|
||||
expr fvar = m_fvars.back();
|
||||
expr x = m_fvars.back();
|
||||
m_fvars.pop_back();
|
||||
bool used_in_e = (e_fvars.find(fvar_name(fvar)) != e_fvars.end());
|
||||
bool used_in_entries = (entries_fvars.find(fvar_name(fvar)) != entries_fvars.end());
|
||||
bool used_in_e = (e_fvars.find(fvar_name(x)) != e_fvars.end());
|
||||
bool used_in_entries = (entries_fvars.find(fvar_name(x)) != entries_fvars.end());
|
||||
if (!used_in_e && !used_in_entries) {
|
||||
/* Skip unused variables */
|
||||
continue;
|
||||
}
|
||||
local_decl decl = m_lctx.get_local_decl(fvar);
|
||||
expr type = decl.get_type();
|
||||
expr val = *decl.get_value();
|
||||
local_decl x_decl = m_lctx.get_local_decl(x);
|
||||
expr type = x_decl.get_type();
|
||||
expr val = *x_decl.get_value();
|
||||
bool modified_val = false;
|
||||
if (is_lambda(val)) {
|
||||
/* We don't simplify lambdas when we visit `let`-expressions. */
|
||||
|
|
@ -672,7 +731,7 @@ class csimp_fn {
|
|||
lean_assert(m_fvars.size() == saved_fvars_size);
|
||||
}
|
||||
|
||||
if (is_fvar(e) && entries.empty() && fvar_name(e) == fvar_name(fvar)) {
|
||||
if (entries.empty() && e == x) {
|
||||
/* `let x := v in x` ==> `v` */
|
||||
e = val;
|
||||
collect_used(val, e_fvars);
|
||||
|
|
@ -684,36 +743,38 @@ class csimp_fn {
|
|||
/* Float cases transformation. */
|
||||
if (m_cfg.m_float_cases) {
|
||||
/* We first create a let-declaration with all entries that depends on the current
|
||||
`fvar` which is a cases_on application. */
|
||||
`x` which is a cases_on application. */
|
||||
buffer<pair<expr, expr>> entries_dep_curr;
|
||||
buffer<pair<expr, expr>> entries_ndep_curr;
|
||||
split_entries(entries, fvar, entries_dep_curr, entries_ndep_curr);
|
||||
split_entries(entries, x, entries_dep_curr, entries_ndep_curr);
|
||||
expr new_e = mk_let(entries_dep_curr, e);
|
||||
buffer<expr> new_jps;
|
||||
if (optional<expr> new_e_opt = float_cases_on(fvar, new_e, val, new_jps)) {
|
||||
e = *new_e_opt;
|
||||
/* Reset `e_fvars` and `entries_fvars`, we need to reconstruct them. */
|
||||
e_fvars.clear(); entries_fvars.clear();
|
||||
collect_used(e, e_fvars);
|
||||
/* Join points may have been generated, we move them to entries. */
|
||||
entries.clear();
|
||||
while (!new_jps.empty()) {
|
||||
expr jp_fvar = new_jps.back();
|
||||
new_jps.pop_back();
|
||||
local_decl jp_decl = m_lctx.get_local_decl(jp_fvar);
|
||||
expr jp_type = jp_decl.get_type();
|
||||
expr jp_val = *jp_decl.get_value();
|
||||
collect_used(jp_type, entries_fvars);
|
||||
collect_used(jp_val, entries_fvars);
|
||||
entries.emplace_back(jp_fvar, jp_val);
|
||||
if (is_float_cases_on_worthwhile(x, new_e, val)) {
|
||||
buffer<expr> new_jps;
|
||||
if (optional<expr> new_e_opt = float_cases_on(x, new_e, val, new_jps)) {
|
||||
e = *new_e_opt;
|
||||
/* Reset `e_fvars` and `entries_fvars`, we need to reconstruct them. */
|
||||
e_fvars.clear(); entries_fvars.clear();
|
||||
collect_used(e, e_fvars);
|
||||
/* Join points may have been generated, we move them to entries. */
|
||||
entries.clear();
|
||||
while (!new_jps.empty()) {
|
||||
expr jp_fvar = new_jps.back();
|
||||
new_jps.pop_back();
|
||||
local_decl jp_decl = m_lctx.get_local_decl(jp_fvar);
|
||||
expr jp_type = jp_decl.get_type();
|
||||
expr jp_val = *jp_decl.get_value();
|
||||
collect_used(jp_type, entries_fvars);
|
||||
collect_used(jp_val, entries_fvars);
|
||||
entries.emplace_back(jp_fvar, jp_val);
|
||||
}
|
||||
/* Copy `entries_ndep_curr` to `entries` */
|
||||
for (unsigned i = 0; i < entries_ndep_curr.size(); i++) {
|
||||
pair<expr, expr> const & ndep_entry = entries_ndep_curr[i];
|
||||
entries.push_back(ndep_entry);
|
||||
collect_used(ndep_entry.second, entries_fvars);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
/* Copy `entries_ndep_curr` to `entries` */
|
||||
for (unsigned i = 0; i < entries_ndep_curr.size(); i++) {
|
||||
pair<expr, expr> const & ndep_entry = entries_ndep_curr[i];
|
||||
entries.push_back(ndep_entry);
|
||||
collect_used(ndep_entry.second, entries_fvars);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
val = visit_cases_default(val);
|
||||
|
|
@ -721,26 +782,26 @@ class csimp_fn {
|
|||
}
|
||||
|
||||
if (e_is_cases && used_in_e) {
|
||||
optional<unsigned> minor_idx = used_in_one_minor(e, fvar);
|
||||
optional<unsigned> minor_idx = used_in_one_minor(e, x);
|
||||
if (minor_idx && !used_in_entries) {
|
||||
/* If fvar is only used in only one minor declaration,
|
||||
/* If x is only used in only one minor declaration,
|
||||
and is *not* used in any expression at entries */
|
||||
if (modified_val) {
|
||||
/* We need to create a new free variable since the new
|
||||
simplified value `val` */
|
||||
expr new_fvar = m_lctx.mk_local_decl(ngen(), decl.get_user_name(), type, val);
|
||||
e = replace_fvar(e, fvar, new_fvar);
|
||||
fvar = new_fvar;
|
||||
expr new_x = m_lctx.mk_local_decl(ngen(), x_decl.get_user_name(), type, val);
|
||||
e = replace_fvar(e, x, new_x);
|
||||
x = new_x;
|
||||
}
|
||||
collect_used(type, e_fvars);
|
||||
collect_used(val, e_fvars);
|
||||
e = move_let_to_minor(e, *minor_idx, fvar);
|
||||
e = move_let_to_minor(e, *minor_idx, x);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
collect_used(type, entries_fvars);
|
||||
collect_used(val, entries_fvars);
|
||||
entries.emplace_back(fvar, val);
|
||||
entries.emplace_back(x, val);
|
||||
}
|
||||
return mk_let(entries, e);
|
||||
}
|
||||
|
|
@ -892,7 +953,7 @@ class csimp_fn {
|
|||
expr const & c = get_app_args(e, args);
|
||||
/* simplify minor premises */
|
||||
unsigned minor_idx; unsigned minors_end;
|
||||
std::tie(minor_idx, minors_end) = get_cases_on_minors_range(const_name(c));
|
||||
std::tie(minor_idx, minors_end) = get_cases_on_minors_range(env(), const_name(c));
|
||||
for (; minor_idx < minors_end; minor_idx++) {
|
||||
expr minor = args[minor_idx];
|
||||
unsigned saved_fvars_size = m_fvars.size();
|
||||
|
|
@ -933,7 +994,7 @@ class csimp_fn {
|
|||
buffer<expr> args;
|
||||
expr const & c = get_app_args(e, args);
|
||||
lean_assert(is_constant(c));
|
||||
inductive_val I_val = env().get(const_name(c).get_prefix()).to_inductive_val();
|
||||
inductive_val I_val = get_cases_on_inductive_val(env(), c);
|
||||
unsigned major_idx = I_val.get_nparams() + 1 /* typeformer/motive */ + I_val.get_nindices();
|
||||
lean_assert(major_idx < args.size());
|
||||
expr const & major = find(args[major_idx]);
|
||||
|
|
@ -994,8 +1055,7 @@ class csimp_fn {
|
|||
b_1 := lc_cast a_1
|
||||
...
|
||||
b_n := lc_cast a_n
|
||||
e := g b_1 ... b_n
|
||||
*/
|
||||
e := g b_1 ... b_n */
|
||||
expr const & g = app_arg(fn);
|
||||
buffer<expr> args;
|
||||
get_app_args(e, args);
|
||||
|
|
|
|||
|
|
@ -125,6 +125,23 @@ bool is_cases_on_recursor(environment const & env, name const & n) {
|
|||
return ::lean::is_aux_recursor(env, n) && n.get_string() == "cases_on";
|
||||
}
|
||||
|
||||
expr get_cases_on_app_major(environment const & env, expr const & c) {
|
||||
lean_assert(is_cases_on_app(env, c));
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(c, args);
|
||||
inductive_val I_val = get_cases_on_inductive_val(env, fn);
|
||||
return args[I_val.get_nparams() + 1 /* motive */ + I_val.get_nindices()];
|
||||
}
|
||||
|
||||
pair<unsigned, unsigned> get_cases_on_minors_range(environment const & env, name const & c) {
|
||||
inductive_val I_val = get_cases_on_inductive_val(env, c);
|
||||
unsigned nparams = I_val.get_nparams();
|
||||
unsigned nindices = I_val.get_nindices();
|
||||
unsigned nminors = I_val.get_ncnstrs();
|
||||
unsigned first_minor_idx = nparams + 1 /*motive*/ + nindices + 1 /* major */;
|
||||
return mk_pair(first_minor_idx, first_minor_idx + nminors);
|
||||
}
|
||||
|
||||
expr mk_lc_unreachable(type_checker::state & s, local_ctx const & lctx, expr const & type) {
|
||||
type_checker tc(s, lctx);
|
||||
level lvl = sort_level(tc.ensure_type(type));
|
||||
|
|
|
|||
|
|
@ -30,10 +30,26 @@ expr unfold_macro_defs(environment const & env, expr const & e);
|
|||
inline bool is_lc_mdata(expr const &) { return false; }
|
||||
|
||||
bool is_cases_on_recursor(environment const & env, name const & n);
|
||||
/* Return the `inductive_val` for the cases_on constant `c`. */
|
||||
inline inductive_val get_cases_on_inductive_val(environment const & env, name const & c) {
|
||||
lean_assert(is_cases_on_recursor(env, c));
|
||||
return env.get(c.get_prefix()).to_inductive_val();
|
||||
}
|
||||
inline inductive_val get_cases_on_inductive_val(environment const & env, expr const & c) {
|
||||
lean_assert(is_constant(c));
|
||||
return get_cases_on_inductive_val(env, const_name(c));
|
||||
}
|
||||
inline bool is_cases_on_app(environment const & env, expr const & e) {
|
||||
expr const & fn = get_app_fn(e);
|
||||
return is_constant(fn) && is_cases_on_recursor(env, const_name(fn));
|
||||
}
|
||||
/* Return the major premise of a cases_on-application.
|
||||
\pre is_cases_on_app(env, c) */
|
||||
expr get_cases_on_app_major(environment const & env, expr const & c);
|
||||
/* Return the pair `(b, e)` such that `i in [b, e)` is argument `i` in a `c` cases_on
|
||||
application is a minor premise.
|
||||
\pre is_cases_on_recursor(env, c) */
|
||||
pair<unsigned, unsigned> get_cases_on_minors_range(environment const & env, name const & c);
|
||||
|
||||
inline bool is_lc_unreachable_app(expr const & e) { return is_app_of(e, get_lc_unreachable_name(), 1); }
|
||||
inline bool is_lc_proof_app(expr const & e) { return is_app_of(e, get_lc_proof_name(), 1); }
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue