feat(library/compiler/csimp): add float_cases_on

This commit is contained in:
Leonardo de Moura 2018-09-24 17:21:31 -07:00
parent ee43d4a20a
commit 017261960c
2 changed files with 123 additions and 78 deletions

View file

@ -30,12 +30,25 @@ class csimp_fn {
name_generator & ngen() { return m_st.ngen(); }
/* Very simple predicate used to decide whether we should inline joint-points or not.
TODO(Leo): improve */
bool is_small(expr const & e) const {
if (is_app(e) && !is_cases_on_app(env(), e))
return true;
if (is_lambda(e))
return is_small(binding_body(e));
return false;
}
expr find(expr const & e, bool skip_mdata = true) const {
if (is_fvar(e)) {
if (optional<local_decl> decl = m_lctx.find_local_decl(e)) {
if (optional<expr> v = decl->get_value())
if (optional<expr> v = decl->get_value()) {
if (!is_join_point_name(decl->get_user_name()))
return find(*v, skip_mdata);
else if (is_small(*v))
return find(*v, skip_mdata);
}
}
} else if (is_mdata(e) && skip_mdata) {
return find(mdata_expr(e), true);
@ -43,6 +56,8 @@ class csimp_fn {
return e;
}
type_checker tc() { return type_checker(m_st, m_lctx); }
expr infer_type(expr const & e) { return type_checker(m_st, m_lctx).infer(e); }
expr whnf(expr const & e) { return type_checker(m_st, m_lctx).whnf(e); }
@ -166,7 +181,7 @@ class csimp_fn {
The main `mk_let` does not modify the user-name/type, but we store them at `entries` just for
convenience and to avoid additional accesses to `m_lctx`.
Remark: the buffers `fvars` and `entries` are reverted by this method. */
Remark: `fvars` and `entries` are not modified. */
expr mk_let(buffer<expr> & fvars, buffer<std::tuple<name, expr, expr>> & entries, expr e) {
lean_assert(fvars.size() == entries.size());
std::reverse(fvars.begin(), fvars.end());
@ -179,9 +194,94 @@ class csimp_fn {
expr new_type = abstract(std::get<1>(entries[i]), i, fvars.data());
e = ::lean::mk_let(std::get<0>(entries[i]), new_type, new_value, e);
}
/* Restore `fvars` and `entries` */
std::reverse(fvars.begin(), fvars.end());
std::reverse(entries.begin(), entries.end());
return e;
}
/* Float cases transformation
```
let x := cases_on m (fun y_1, let ... in e_1)
...
(fun y_n, let ... in e_n))
in e
```
==>
```
cases_on m (fun y_1, let ... x := e_1 in e)
...
(fun y_n, let ... x := e_n in e)
``` */
optional<expr> float_cases_on(expr const & fvar, expr const & c, expr const & e) {
lean_assert(is_cases_on_app(env(), c));
local_decl fvar_decl = m_lctx.get_local_decl(fvar);
expr result_type = whnf_infer_type(e);
buffer<expr> c_args;
expr c_fn = get_app_args(c, c_args);
inductive_val I_val = env().get(const_name(c_fn).get_prefix()).to_inductive_val();
unsigned motive_idx = I_val.get_nparams();
unsigned first_index = motive_idx + 1;
unsigned nindices = I_val.get_nindices();
unsigned major_idx = first_index + nindices;
unsigned first_minor_idx = major_idx + 1;
unsigned nminors = I_val.get_ncnstrs();
/* Update motive */
{
flet<local_ctx> save_lctx(m_lctx, m_lctx);
buffer<expr> fvars;
expr motive = c_args[motive_idx];
expr motive_type = whnf_infer_type(motive);
for (unsigned i = 0; i < nindices + 1; i++) {
lean_assert(is_pi(motive_type));
expr fvar = m_lctx.mk_local_decl(ngen(), binding_name(motive_type), binding_domain(motive_type), binding_info(motive_type));
fvars.push_back(fvar);
motive_type = whnf(instantiate(binding_body(motive_type), fvar));
}
level result_lvl = sort_level(tc().ensure_type(result_type));
expr new_motive = m_lctx.mk_lambda(fvars, result_type);
c_args[motive_idx] = new_motive;
/* We need to update the resultant universe. */
levels new_cases_lvls = levels(result_lvl, tail(const_levels(c_fn)));
c_fn = update_constant(c_fn, new_cases_lvls);
}
flet<local_ctx> save_lctx(m_lctx, m_lctx);
unsigned saved_fvars_size = m_fvars.size();
/* Update minor premises */
for (unsigned i = 0; i < nminors; i++) {
unsigned minor_idx = first_minor_idx + i;
expr minor = c_args[minor_idx];
buffer<expr> minor_fvars;
unsigned old_fvars_size = m_fvars.size();
while (is_lambda(minor)) {
expr new_d = instantiate_rev(binding_domain(minor), minor_fvars.size(), minor_fvars.data());
expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(minor), new_d, binding_info(minor));
minor_fvars.push_back(new_fvar);
minor = binding_body(minor);
}
expr minor_val = visit(instantiate_rev(minor, minor_fvars.size(), minor_fvars.data()), true);
/* TODO(Leo): We need to preserve join points. */
expr minor_val_type = infer_type(minor_val);
minor_val = mk_cast(minor_val_type, fvar_decl.get_type(), minor_val);
expr new_fvar = m_lctx.mk_local_decl(ngen(), fvar_decl.get_user_name(), fvar_decl.get_type(), minor_val);
m_fvars.push_back(new_fvar);
expr new_minor;
if (optional<expr> new_minor_opt = replace_fvar_with(m_st, m_lctx, e, fvar, new_fvar)) {
new_minor = *new_minor_opt;
} else {
m_fvars.resize(saved_fvars_size);
return none_expr(); /* Failed to produce type correct `new_minor` */
}
new_minor = visit(new_minor, false);
expr new_minor_type = infer_type(new_minor);
new_minor = mk_cast(new_minor_type, result_type, new_minor);
new_minor = mk_let(old_fvars_size, new_minor);
new_minor = m_lctx.mk_lambda(minor_fvars, new_minor);
c_args[minor_idx] = new_minor;
}
return some_expr(mk_app(c_fn, c_args));
}
/* Create a let-expression with body `e`, and
all "used" let-declarations `m_fvars[i]` for `i in [old_fvars_size, m_fvars.size)`.
@ -229,9 +329,25 @@ class csimp_fn {
continue;
}
if (is_cases_on_app(env(), val)) {
// TODO(Leo);
} else if (e_is_cases && used_in_e) {
/* TODO(Leo): enable "float cases" as soon as it has descent performance */
if (false && is_cases_on_app(env(), val)) {
/* Float cases transformation. */
DEBUG_CODE(unsigned saved_fvars_size = m_fvars.size(););
expr new_e = mk_let(fvars, entries, e);
/* TODO(Leo): create new joint point if `e` is too big */
if (optional<expr> new_e_opt = float_cases_on(fvar, val, new_e)) {
e = *new_e_opt;
fvars.clear(); entries.clear();
e_fvars.clear(); entries_fvars.clear();
collect_used(e, e_fvars);
lean_assert(m_fvars.size() == saved_fvars_size);
continue;
} else {
lean_assert(m_fvars.size() == saved_fvars_size);
}
}
if (e_is_cases && used_in_e) {
optional<unsigned> minor_idx = used_in_one_minor(e, fvar);
if (minor_idx && !used_in_entries) {
/* If fvar is only used in only one minor declaration,
@ -249,6 +365,7 @@ class csimp_fn {
continue;
}
}
fvars.push_back(fvar);
collect_used(type, entries_fvars);
collect_used(val, entries_fvars);
@ -501,78 +618,6 @@ class csimp_fn {
}
}
// expr distrib_app_cases(expr const & fn, expr const & e) {
// lean_assert(is_cases_on_app(env(), fn));
// lean_assert(is_eqp(find(get_app_fn(e)), fn));
// expr result_type = infer_type(e);
// buffer<expr> args;
// get_app_args(e, args);
// buffer<expr> cases_args;
// expr cases = get_app_args(fn, cases_args);
// lean_assert(is_constant(cases));
// inductive_val I_val = env().get(const_name(cases).get_prefix()).to_inductive_val();
// unsigned motive_idx = I_val.get_nparams();
// unsigned first_index = motive_idx + 1;
// unsigned nindices = I_val.get_nindices();
// unsigned major_idx = first_index + nindices;
// unsigned first_minor_idx = major_idx + 1;
// unsigned nminors = length(I_val.get_cnstrs());
// /* Infer argument types */
// buffer<expr> arg_types;
// {
// type_checker tc(m_st, m_lctx);
// for (expr const & arg : args) {
// arg_types.push_back(tc.infer(arg));
// }
// }
// /* Update motive */
// {
// flet<local_ctx> save_lctx(m_lctx, m_lctx);
// buffer<expr> fvars;
// expr motive = cases_args[motive_idx];
// expr motive_type = whnf_infer_type(motive);
// for (unsigned i = 0; i < nindices + 1; i++) {
// lean_assert(is_pi(motive_type));
// expr fvar = m_lctx.mk_local_decl(ngen(), binding_name(motive_type), binding_domain(motive_type), binding_info(motive_type));
// fvars.push_back(fvar);
// motive_type = whnf(instantiate(binding_body(motive_type), fvar));
// }
// level result_lvl = sort_level(type_checker(env(), m_lctx).ensure_type(result_type));
// expr new_motive = m_lctx.mk_lambda(fvars, result_type);
// cases_args[motive_idx] = new_motive;
// /* We need to update the resultant universe. */
// levels new_cases_lvls = levels(result_lvl, tail(const_levels(cases)));
// cases = update_constant(cases, new_cases_lvls);
// }
// /* Update minor premises */
// for (unsigned i = 0; i < nminors; i++) {
// unsigned minor_idx = first_minor_idx + i;
// expr minor = cases_args[minor_idx];
// flet<local_ctx> save_lctx(m_lctx, m_lctx);
// buffer<expr> minor_fvars;
// unsigned old_fvars_size = m_fvars.size();
// while (is_lambda(minor)) {
// expr new_d = instantiate_rev(binding_domain(minor), minor_fvars.size(), minor_fvars.data());
// expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(minor), new_d, binding_info(minor));
// minor_fvars.push_back(new_fvar);
// minor = binding_body(minor);
// }
// expr new_minor = visit(instantiate_rev(minor, minor_fvars.size(), minor_fvars.data()));
// for (unsigned i = 0; i < args.size(); i++) {
// expr new_minor_type = whnf_infer_type(new_minor);
// lean_assert(is_pi(new_minor_type));
// new_minor = mk_app(new_minor, mk_cast(arg_types[i], binding_domain(new_minor_type), args[i]));
// }
// new_minor = visit_let_value(new_minor);
// type_checker tc(m_st, m_lctx);
// new_minor = mk_cast(tc, tc.infer(new_minor), result_type, new_minor);
// new_minor = mk_let(old_fvars_size, new_minor);
// new_minor = m_lctx.mk_lambda(minor_fvars, new_minor);
// cases_args[minor_idx] = new_minor;
// }
// return mk_let_decl(mk_app(cases, cases_args));
// }
expr reduce_lc_cast(expr const & e) {
buffer<expr> args;
expr const & cast_fn1 = get_app_args(e, args);

View file

@ -2,7 +2,7 @@ import init.lean.parser.parsec
import init.control.coroutine
universes u v w r s
set_option trace.compiler.lcnf true
set_option trace.compiler.stage1 true
-- set_option pp.implicit true
set_option pp.binder_types false
set_option pp.proofs true