feat(library/compiler): add is_float_cases_on_worthwhile predicate and cleanup

This commit is contained in:
Leonardo de Moura 2018-09-27 13:12:20 -07:00
parent 78de3de764
commit d880b1c640
3 changed files with 160 additions and 67 deletions

View file

@ -23,16 +23,16 @@ csimp_cfg::csimp_cfg() {
m_inline = true;
m_inline_threshold = 4;
m_float_cases_app = true;
m_float_cases = false;
m_float_cases_jp_threshold = 3;
m_float_cases_jp_branch_threshold = 3;
m_inline_jp_threshold = 4;
m_float_cases = true;
m_float_cases_jp_threshold = 4;
m_float_cases_jp_branch_threshold = 2;
m_inline_jp_threshold = 2;
}
class csimp_fn {
type_checker::state m_st;
local_ctx m_lctx;
csimp_cfg const & m_cfg;
csimp_cfg m_cfg;
buffer<expr> m_fvars;
name m_x;
name m_j;
@ -127,16 +127,6 @@ class csimp_fn {
return fvar;
}
/* Given the `cases_on` application, return [first_minor_idx, first_minor_idx + nminors) */
pair<unsigned, unsigned> get_cases_on_minors_range(name const & cases) {
inductive_val I_val = env().get(cases.get_prefix()).to_inductive_val();
unsigned nparams = I_val.get_nparams();
unsigned nindices = I_val.get_nindices();
unsigned nminors = I_val.get_ncnstrs();
unsigned first_minor_idx = nparams + 1 /*motive*/ + nindices + 1 /* major */;
return mk_pair(first_minor_idx, first_minor_idx + nminors);
}
/* Given a cases_on application `c`, return `some idx` iff `fvar` only occurs
in the argument `idx`, this argument is a minor premise. */
optional<unsigned> used_in_one_minor(expr const & c, expr const & fvar) {
@ -145,7 +135,7 @@ class csimp_fn {
buffer<expr> args;
expr const & c_fn = get_app_args(c, args);
unsigned minors_begin; unsigned minors_end;
std::tie(minors_begin, minors_end) = get_cases_on_minors_range(const_name(c_fn));
std::tie(minors_begin, minors_end) = get_cases_on_minors_range(env(), const_name(c_fn));
unsigned i = 0;
for (; i < minors_begin; i++) {
if (has_fvar(args[i], fvar)) {
@ -237,6 +227,77 @@ class csimp_fn {
});
}
/* Return true iff the free variable `x` occurs in a projection or is the major premise of
a `cases_on` application in `e`. */
bool is_proj_or_cases_on_arg_at(expr const & x, expr const & e) {
lean_assert(is_fvar(x));
if (!has_fvar(e)) return false;
bool found = false;
for_each(e, [&](expr const & s, unsigned) {
if (!has_fvar(s)) return false;
if (found) return false;
if (is_proj(s) && proj_expr(s) == x) {
found = true;
return false;
} else if (is_cases_on_app(env(), s) && get_cases_on_app_major(env(), s) == x) {
found = true;
return false;
}
return true;
});
return found;
}
/* Auxiliary method for `may_return_constructor`. `visited` contains the set of join points that
have already been visited. */
bool may_return_constructor_core(expr e, name_set & visited) {
while (true) {
if (is_lambda(e))
e = binding_body(e);
else if (is_let(e))
e = let_body(e);
else
break;
}
if (is_constructor_app(env(), e)) {
return true;
} else if (is_cases_on_app(env(), e)) {
buffer<expr> args;
expr const & fn = get_app_args(e, args);
unsigned begin_minors; unsigned end_minors;
std::tie(begin_minors, end_minors) = get_cases_on_minors_range(env(), const_name(fn));
for (unsigned i = begin_minors; i < end_minors; i++) {
if (may_return_constructor_core(args[i], visited))
return true;
}
return false;
} else if (is_join_point_app(e)) {
expr const & fn = get_app_fn(e);
lean_assert(is_fvar(fn));
if (visited.find(fvar_name(fn)) != visited.end())
return false;
visited.insert(fvar_name(fn));
local_decl decl = m_lctx.get_local_decl(fn);
return may_return_constructor_core(*decl.get_value(), visited);
} else {
return false;
}
}
/* Return true if `e` may return a constructor application. We say "may" because the
result may be a `cases_on`-application and we return true if one of the branches return a constructor. */
bool may_return_constructor(expr const & e) {
name_set visited;
return may_return_constructor_core(e, visited);
}
bool is_float_cases_on_worthwhile(expr const & x, expr const & e, expr const & c) {
lean_assert(is_cases_on_app(env(), c));
return
is_proj_or_cases_on_arg_at(x, e) &&
may_return_constructor(c);
}
/* The `float_cases_on` transformation may produce code duplication.
The term `e` is "copied" in each branch of the the `cases_on` expression `c`.
This method creates one (or more) join-point(s) for `e` (if needed).
@ -244,7 +305,7 @@ class csimp_fn {
optional<expr> mk_join_point_float_cases_on(expr const & fvar, expr const & e, expr const & c, buffer<expr> & new_jps) {
lean_assert(is_cases_on_app(env(), c));
expr const & c_fn = get_app_fn(c);
inductive_val I_val = env().get(const_name(c_fn).get_prefix()).to_inductive_val();
inductive_val I_val = get_cases_on_inductive_val(env(), c_fn);
if (I_val.get_ncnstrs() == 1) {
/* `c` has only one case. So, only one copy of `e` may be created. */
return some_expr(e);
@ -258,7 +319,7 @@ class csimp_fn {
bool modified = false;
unsigned saved_fvars_size = m_fvars.size();
unsigned begin_minors; unsigned end_minors;
std::tie(begin_minors, end_minors) = get_cases_on_minors_range(const_name(fn));
std::tie(begin_minors, end_minors) = get_cases_on_minors_range(env(), const_name(fn));
for (unsigned minor_idx = begin_minors; minor_idx < end_minors; minor_idx++) {
expr minor = args[minor_idx];
if (get_lcnf_size(env(), minor) > m_cfg.m_float_cases_jp_branch_threshold) {
@ -467,7 +528,7 @@ class csimp_fn {
expr result_type = whnf_infer_type(e);
buffer<expr> c_args;
expr c_fn = get_app_args(c, c_args);
inductive_val I_val = env().get(const_name(c_fn).get_prefix()).to_inductive_val();
inductive_val I_val = get_cases_on_inductive_val(env(), c_fn);
unsigned motive_idx = I_val.get_nparams();
unsigned first_index = motive_idx + 1;
unsigned nindices = I_val.get_nindices();
@ -549,8 +610,7 @@ class csimp_fn {
x_1 := w_1
in e
```
The values `w_i` are the "simplified values" for the let-declaration `x_i`.
*/
The values `w_i` are the "simplified values" for the let-declaration `x_i`. */
expr mk_let(buffer<pair<expr, expr>> const & entries, expr e) {
buffer<expr> fvars;
buffer<name> user_names;
@ -604,8 +664,7 @@ class csimp_fn {
...
x_1 := w_1
in e
```
*/
``` */
void split_entries(buffer<pair<expr, expr>> const & entries,
expr const & x,
buffer<pair<expr, expr>> & entries_dep_x,
@ -652,17 +711,17 @@ class csimp_fn {
collect_used(e, e_fvars);
bool e_is_cases = is_cases_on_app(env(), e);
while (m_fvars.size() > saved_fvars_size) {
expr fvar = m_fvars.back();
expr x = m_fvars.back();
m_fvars.pop_back();
bool used_in_e = (e_fvars.find(fvar_name(fvar)) != e_fvars.end());
bool used_in_entries = (entries_fvars.find(fvar_name(fvar)) != entries_fvars.end());
bool used_in_e = (e_fvars.find(fvar_name(x)) != e_fvars.end());
bool used_in_entries = (entries_fvars.find(fvar_name(x)) != entries_fvars.end());
if (!used_in_e && !used_in_entries) {
/* Skip unused variables */
continue;
}
local_decl decl = m_lctx.get_local_decl(fvar);
expr type = decl.get_type();
expr val = *decl.get_value();
local_decl x_decl = m_lctx.get_local_decl(x);
expr type = x_decl.get_type();
expr val = *x_decl.get_value();
bool modified_val = false;
if (is_lambda(val)) {
/* We don't simplify lambdas when we visit `let`-expressions. */
@ -672,7 +731,7 @@ class csimp_fn {
lean_assert(m_fvars.size() == saved_fvars_size);
}
if (is_fvar(e) && entries.empty() && fvar_name(e) == fvar_name(fvar)) {
if (entries.empty() && e == x) {
/* `let x := v in x` ==> `v` */
e = val;
collect_used(val, e_fvars);
@ -684,36 +743,38 @@ class csimp_fn {
/* Float cases transformation. */
if (m_cfg.m_float_cases) {
/* We first create a let-declaration with all entries that depends on the current
`fvar` which is a cases_on application. */
`x` which is a cases_on application. */
buffer<pair<expr, expr>> entries_dep_curr;
buffer<pair<expr, expr>> entries_ndep_curr;
split_entries(entries, fvar, entries_dep_curr, entries_ndep_curr);
split_entries(entries, x, entries_dep_curr, entries_ndep_curr);
expr new_e = mk_let(entries_dep_curr, e);
buffer<expr> new_jps;
if (optional<expr> new_e_opt = float_cases_on(fvar, new_e, val, new_jps)) {
e = *new_e_opt;
/* Reset `e_fvars` and `entries_fvars`, we need to reconstruct them. */
e_fvars.clear(); entries_fvars.clear();
collect_used(e, e_fvars);
/* Join points may have been generated, we move them to entries. */
entries.clear();
while (!new_jps.empty()) {
expr jp_fvar = new_jps.back();
new_jps.pop_back();
local_decl jp_decl = m_lctx.get_local_decl(jp_fvar);
expr jp_type = jp_decl.get_type();
expr jp_val = *jp_decl.get_value();
collect_used(jp_type, entries_fvars);
collect_used(jp_val, entries_fvars);
entries.emplace_back(jp_fvar, jp_val);
if (is_float_cases_on_worthwhile(x, new_e, val)) {
buffer<expr> new_jps;
if (optional<expr> new_e_opt = float_cases_on(x, new_e, val, new_jps)) {
e = *new_e_opt;
/* Reset `e_fvars` and `entries_fvars`, we need to reconstruct them. */
e_fvars.clear(); entries_fvars.clear();
collect_used(e, e_fvars);
/* Join points may have been generated, we move them to entries. */
entries.clear();
while (!new_jps.empty()) {
expr jp_fvar = new_jps.back();
new_jps.pop_back();
local_decl jp_decl = m_lctx.get_local_decl(jp_fvar);
expr jp_type = jp_decl.get_type();
expr jp_val = *jp_decl.get_value();
collect_used(jp_type, entries_fvars);
collect_used(jp_val, entries_fvars);
entries.emplace_back(jp_fvar, jp_val);
}
/* Copy `entries_ndep_curr` to `entries` */
for (unsigned i = 0; i < entries_ndep_curr.size(); i++) {
pair<expr, expr> const & ndep_entry = entries_ndep_curr[i];
entries.push_back(ndep_entry);
collect_used(ndep_entry.second, entries_fvars);
}
continue;
}
/* Copy `entries_ndep_curr` to `entries` */
for (unsigned i = 0; i < entries_ndep_curr.size(); i++) {
pair<expr, expr> const & ndep_entry = entries_ndep_curr[i];
entries.push_back(ndep_entry);
collect_used(ndep_entry.second, entries_fvars);
}
continue;
}
}
val = visit_cases_default(val);
@ -721,26 +782,26 @@ class csimp_fn {
}
if (e_is_cases && used_in_e) {
optional<unsigned> minor_idx = used_in_one_minor(e, fvar);
optional<unsigned> minor_idx = used_in_one_minor(e, x);
if (minor_idx && !used_in_entries) {
/* If fvar is only used in only one minor declaration,
/* If x is only used in only one minor declaration,
and is *not* used in any expression at entries */
if (modified_val) {
/* We need to create a new free variable since the new
simplified value `val` */
expr new_fvar = m_lctx.mk_local_decl(ngen(), decl.get_user_name(), type, val);
e = replace_fvar(e, fvar, new_fvar);
fvar = new_fvar;
expr new_x = m_lctx.mk_local_decl(ngen(), x_decl.get_user_name(), type, val);
e = replace_fvar(e, x, new_x);
x = new_x;
}
collect_used(type, e_fvars);
collect_used(val, e_fvars);
e = move_let_to_minor(e, *minor_idx, fvar);
e = move_let_to_minor(e, *minor_idx, x);
continue;
}
}
collect_used(type, entries_fvars);
collect_used(val, entries_fvars);
entries.emplace_back(fvar, val);
entries.emplace_back(x, val);
}
return mk_let(entries, e);
}
@ -892,7 +953,7 @@ class csimp_fn {
expr const & c = get_app_args(e, args);
/* simplify minor premises */
unsigned minor_idx; unsigned minors_end;
std::tie(minor_idx, minors_end) = get_cases_on_minors_range(const_name(c));
std::tie(minor_idx, minors_end) = get_cases_on_minors_range(env(), const_name(c));
for (; minor_idx < minors_end; minor_idx++) {
expr minor = args[minor_idx];
unsigned saved_fvars_size = m_fvars.size();
@ -933,7 +994,7 @@ class csimp_fn {
buffer<expr> args;
expr const & c = get_app_args(e, args);
lean_assert(is_constant(c));
inductive_val I_val = env().get(const_name(c).get_prefix()).to_inductive_val();
inductive_val I_val = get_cases_on_inductive_val(env(), c);
unsigned major_idx = I_val.get_nparams() + 1 /* typeformer/motive */ + I_val.get_nindices();
lean_assert(major_idx < args.size());
expr const & major = find(args[major_idx]);
@ -994,8 +1055,7 @@ class csimp_fn {
b_1 := lc_cast a_1
...
b_n := lc_cast a_n
e := g b_1 ... b_n
*/
e := g b_1 ... b_n */
expr const & g = app_arg(fn);
buffer<expr> args;
get_app_args(e, args);

View file

@ -125,6 +125,23 @@ bool is_cases_on_recursor(environment const & env, name const & n) {
return ::lean::is_aux_recursor(env, n) && n.get_string() == "cases_on";
}
expr get_cases_on_app_major(environment const & env, expr const & c) {
lean_assert(is_cases_on_app(env, c));
buffer<expr> args;
expr const & fn = get_app_args(c, args);
inductive_val I_val = get_cases_on_inductive_val(env, fn);
return args[I_val.get_nparams() + 1 /* motive */ + I_val.get_nindices()];
}
pair<unsigned, unsigned> get_cases_on_minors_range(environment const & env, name const & c) {
inductive_val I_val = get_cases_on_inductive_val(env, c);
unsigned nparams = I_val.get_nparams();
unsigned nindices = I_val.get_nindices();
unsigned nminors = I_val.get_ncnstrs();
unsigned first_minor_idx = nparams + 1 /*motive*/ + nindices + 1 /* major */;
return mk_pair(first_minor_idx, first_minor_idx + nminors);
}
expr mk_lc_unreachable(type_checker::state & s, local_ctx const & lctx, expr const & type) {
type_checker tc(s, lctx);
level lvl = sort_level(tc.ensure_type(type));

View file

@ -30,10 +30,26 @@ expr unfold_macro_defs(environment const & env, expr const & e);
inline bool is_lc_mdata(expr const &) { return false; }
bool is_cases_on_recursor(environment const & env, name const & n);
/* Return the `inductive_val` for the cases_on constant `c`. */
inline inductive_val get_cases_on_inductive_val(environment const & env, name const & c) {
lean_assert(is_cases_on_recursor(env, c));
return env.get(c.get_prefix()).to_inductive_val();
}
inline inductive_val get_cases_on_inductive_val(environment const & env, expr const & c) {
lean_assert(is_constant(c));
return get_cases_on_inductive_val(env, const_name(c));
}
inline bool is_cases_on_app(environment const & env, expr const & e) {
expr const & fn = get_app_fn(e);
return is_constant(fn) && is_cases_on_recursor(env, const_name(fn));
}
/* Return the major premise of a cases_on-application.
\pre is_cases_on_app(env, c) */
expr get_cases_on_app_major(environment const & env, expr const & c);
/* Return the pair `(b, e)` such that `i in [b, e)` is argument `i` in a `c` cases_on
application is a minor premise.
\pre is_cases_on_recursor(env, c) */
pair<unsigned, unsigned> get_cases_on_minors_range(environment const & env, name const & c);
inline bool is_lc_unreachable_app(expr const & e) { return is_app_of(e, get_lc_unreachable_name(), 1); }
inline bool is_lc_proof_app(expr const & e) { return is_app_of(e, get_lc_proof_name(), 1); }