feat(library/compiler/csimp): add float_cases_on
This commit is contained in:
parent
ee43d4a20a
commit
017261960c
2 changed files with 123 additions and 78 deletions
|
|
@ -30,12 +30,25 @@ class csimp_fn {
|
|||
|
||||
name_generator & ngen() { return m_st.ngen(); }
|
||||
|
||||
/* Very simple predicate used to decide whether we should inline joint-points or not.
|
||||
TODO(Leo): improve */
|
||||
bool is_small(expr const & e) const {
|
||||
if (is_app(e) && !is_cases_on_app(env(), e))
|
||||
return true;
|
||||
if (is_lambda(e))
|
||||
return is_small(binding_body(e));
|
||||
return false;
|
||||
}
|
||||
|
||||
expr find(expr const & e, bool skip_mdata = true) const {
|
||||
if (is_fvar(e)) {
|
||||
if (optional<local_decl> decl = m_lctx.find_local_decl(e)) {
|
||||
if (optional<expr> v = decl->get_value())
|
||||
if (optional<expr> v = decl->get_value()) {
|
||||
if (!is_join_point_name(decl->get_user_name()))
|
||||
return find(*v, skip_mdata);
|
||||
else if (is_small(*v))
|
||||
return find(*v, skip_mdata);
|
||||
}
|
||||
}
|
||||
} else if (is_mdata(e) && skip_mdata) {
|
||||
return find(mdata_expr(e), true);
|
||||
|
|
@ -43,6 +56,8 @@ class csimp_fn {
|
|||
return e;
|
||||
}
|
||||
|
||||
type_checker tc() { return type_checker(m_st, m_lctx); }
|
||||
|
||||
expr infer_type(expr const & e) { return type_checker(m_st, m_lctx).infer(e); }
|
||||
|
||||
expr whnf(expr const & e) { return type_checker(m_st, m_lctx).whnf(e); }
|
||||
|
|
@ -166,7 +181,7 @@ class csimp_fn {
|
|||
The main `mk_let` does not modify the user-name/type, but we store them at `entries` just for
|
||||
convenience and to avoid additional accesses to `m_lctx`.
|
||||
|
||||
Remark: the buffers `fvars` and `entries` are reverted by this method. */
|
||||
Remark: `fvars` and `entries` are not modified. */
|
||||
expr mk_let(buffer<expr> & fvars, buffer<std::tuple<name, expr, expr>> & entries, expr e) {
|
||||
lean_assert(fvars.size() == entries.size());
|
||||
std::reverse(fvars.begin(), fvars.end());
|
||||
|
|
@ -179,9 +194,94 @@ class csimp_fn {
|
|||
expr new_type = abstract(std::get<1>(entries[i]), i, fvars.data());
|
||||
e = ::lean::mk_let(std::get<0>(entries[i]), new_type, new_value, e);
|
||||
}
|
||||
/* Restore `fvars` and `entries` */
|
||||
std::reverse(fvars.begin(), fvars.end());
|
||||
std::reverse(entries.begin(), entries.end());
|
||||
return e;
|
||||
}
|
||||
|
||||
/* Float cases transformation
|
||||
```
|
||||
let x := cases_on m (fun y_1, let ... in e_1)
|
||||
...
|
||||
(fun y_n, let ... in e_n))
|
||||
in e
|
||||
```
|
||||
==>
|
||||
```
|
||||
cases_on m (fun y_1, let ... x := e_1 in e)
|
||||
...
|
||||
(fun y_n, let ... x := e_n in e)
|
||||
``` */
|
||||
optional<expr> float_cases_on(expr const & fvar, expr const & c, expr const & e) {
|
||||
lean_assert(is_cases_on_app(env(), c));
|
||||
local_decl fvar_decl = m_lctx.get_local_decl(fvar);
|
||||
expr result_type = whnf_infer_type(e);
|
||||
buffer<expr> c_args;
|
||||
expr c_fn = get_app_args(c, c_args);
|
||||
inductive_val I_val = env().get(const_name(c_fn).get_prefix()).to_inductive_val();
|
||||
unsigned motive_idx = I_val.get_nparams();
|
||||
unsigned first_index = motive_idx + 1;
|
||||
unsigned nindices = I_val.get_nindices();
|
||||
unsigned major_idx = first_index + nindices;
|
||||
unsigned first_minor_idx = major_idx + 1;
|
||||
unsigned nminors = I_val.get_ncnstrs();
|
||||
/* Update motive */
|
||||
{
|
||||
flet<local_ctx> save_lctx(m_lctx, m_lctx);
|
||||
buffer<expr> fvars;
|
||||
expr motive = c_args[motive_idx];
|
||||
expr motive_type = whnf_infer_type(motive);
|
||||
for (unsigned i = 0; i < nindices + 1; i++) {
|
||||
lean_assert(is_pi(motive_type));
|
||||
expr fvar = m_lctx.mk_local_decl(ngen(), binding_name(motive_type), binding_domain(motive_type), binding_info(motive_type));
|
||||
fvars.push_back(fvar);
|
||||
motive_type = whnf(instantiate(binding_body(motive_type), fvar));
|
||||
}
|
||||
level result_lvl = sort_level(tc().ensure_type(result_type));
|
||||
expr new_motive = m_lctx.mk_lambda(fvars, result_type);
|
||||
c_args[motive_idx] = new_motive;
|
||||
/* We need to update the resultant universe. */
|
||||
levels new_cases_lvls = levels(result_lvl, tail(const_levels(c_fn)));
|
||||
c_fn = update_constant(c_fn, new_cases_lvls);
|
||||
}
|
||||
flet<local_ctx> save_lctx(m_lctx, m_lctx);
|
||||
unsigned saved_fvars_size = m_fvars.size();
|
||||
/* Update minor premises */
|
||||
for (unsigned i = 0; i < nminors; i++) {
|
||||
unsigned minor_idx = first_minor_idx + i;
|
||||
expr minor = c_args[minor_idx];
|
||||
buffer<expr> minor_fvars;
|
||||
unsigned old_fvars_size = m_fvars.size();
|
||||
while (is_lambda(minor)) {
|
||||
expr new_d = instantiate_rev(binding_domain(minor), minor_fvars.size(), minor_fvars.data());
|
||||
expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(minor), new_d, binding_info(minor));
|
||||
minor_fvars.push_back(new_fvar);
|
||||
minor = binding_body(minor);
|
||||
}
|
||||
expr minor_val = visit(instantiate_rev(minor, minor_fvars.size(), minor_fvars.data()), true);
|
||||
/* TODO(Leo): We need to preserve join points. */
|
||||
expr minor_val_type = infer_type(minor_val);
|
||||
minor_val = mk_cast(minor_val_type, fvar_decl.get_type(), minor_val);
|
||||
expr new_fvar = m_lctx.mk_local_decl(ngen(), fvar_decl.get_user_name(), fvar_decl.get_type(), minor_val);
|
||||
m_fvars.push_back(new_fvar);
|
||||
expr new_minor;
|
||||
if (optional<expr> new_minor_opt = replace_fvar_with(m_st, m_lctx, e, fvar, new_fvar)) {
|
||||
new_minor = *new_minor_opt;
|
||||
} else {
|
||||
m_fvars.resize(saved_fvars_size);
|
||||
return none_expr(); /* Failed to produce type correct `new_minor` */
|
||||
}
|
||||
new_minor = visit(new_minor, false);
|
||||
expr new_minor_type = infer_type(new_minor);
|
||||
new_minor = mk_cast(new_minor_type, result_type, new_minor);
|
||||
new_minor = mk_let(old_fvars_size, new_minor);
|
||||
new_minor = m_lctx.mk_lambda(minor_fvars, new_minor);
|
||||
c_args[minor_idx] = new_minor;
|
||||
}
|
||||
return some_expr(mk_app(c_fn, c_args));
|
||||
}
|
||||
|
||||
/* Create a let-expression with body `e`, and
|
||||
all "used" let-declarations `m_fvars[i]` for `i in [old_fvars_size, m_fvars.size)`.
|
||||
|
||||
|
|
@ -229,9 +329,25 @@ class csimp_fn {
|
|||
continue;
|
||||
}
|
||||
|
||||
if (is_cases_on_app(env(), val)) {
|
||||
// TODO(Leo);
|
||||
} else if (e_is_cases && used_in_e) {
|
||||
/* TODO(Leo): enable "float cases" as soon as it has descent performance */
|
||||
if (false && is_cases_on_app(env(), val)) {
|
||||
/* Float cases transformation. */
|
||||
DEBUG_CODE(unsigned saved_fvars_size = m_fvars.size(););
|
||||
expr new_e = mk_let(fvars, entries, e);
|
||||
/* TODO(Leo): create new joint point if `e` is too big */
|
||||
if (optional<expr> new_e_opt = float_cases_on(fvar, val, new_e)) {
|
||||
e = *new_e_opt;
|
||||
fvars.clear(); entries.clear();
|
||||
e_fvars.clear(); entries_fvars.clear();
|
||||
collect_used(e, e_fvars);
|
||||
lean_assert(m_fvars.size() == saved_fvars_size);
|
||||
continue;
|
||||
} else {
|
||||
lean_assert(m_fvars.size() == saved_fvars_size);
|
||||
}
|
||||
}
|
||||
|
||||
if (e_is_cases && used_in_e) {
|
||||
optional<unsigned> minor_idx = used_in_one_minor(e, fvar);
|
||||
if (minor_idx && !used_in_entries) {
|
||||
/* If fvar is only used in only one minor declaration,
|
||||
|
|
@ -249,6 +365,7 @@ class csimp_fn {
|
|||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
fvars.push_back(fvar);
|
||||
collect_used(type, entries_fvars);
|
||||
collect_used(val, entries_fvars);
|
||||
|
|
@ -501,78 +618,6 @@ class csimp_fn {
|
|||
}
|
||||
}
|
||||
|
||||
// expr distrib_app_cases(expr const & fn, expr const & e) {
|
||||
// lean_assert(is_cases_on_app(env(), fn));
|
||||
// lean_assert(is_eqp(find(get_app_fn(e)), fn));
|
||||
// expr result_type = infer_type(e);
|
||||
// buffer<expr> args;
|
||||
// get_app_args(e, args);
|
||||
// buffer<expr> cases_args;
|
||||
// expr cases = get_app_args(fn, cases_args);
|
||||
// lean_assert(is_constant(cases));
|
||||
// inductive_val I_val = env().get(const_name(cases).get_prefix()).to_inductive_val();
|
||||
// unsigned motive_idx = I_val.get_nparams();
|
||||
// unsigned first_index = motive_idx + 1;
|
||||
// unsigned nindices = I_val.get_nindices();
|
||||
// unsigned major_idx = first_index + nindices;
|
||||
// unsigned first_minor_idx = major_idx + 1;
|
||||
// unsigned nminors = length(I_val.get_cnstrs());
|
||||
// /* Infer argument types */
|
||||
// buffer<expr> arg_types;
|
||||
// {
|
||||
// type_checker tc(m_st, m_lctx);
|
||||
// for (expr const & arg : args) {
|
||||
// arg_types.push_back(tc.infer(arg));
|
||||
// }
|
||||
// }
|
||||
// /* Update motive */
|
||||
// {
|
||||
// flet<local_ctx> save_lctx(m_lctx, m_lctx);
|
||||
// buffer<expr> fvars;
|
||||
// expr motive = cases_args[motive_idx];
|
||||
// expr motive_type = whnf_infer_type(motive);
|
||||
// for (unsigned i = 0; i < nindices + 1; i++) {
|
||||
// lean_assert(is_pi(motive_type));
|
||||
// expr fvar = m_lctx.mk_local_decl(ngen(), binding_name(motive_type), binding_domain(motive_type), binding_info(motive_type));
|
||||
// fvars.push_back(fvar);
|
||||
// motive_type = whnf(instantiate(binding_body(motive_type), fvar));
|
||||
// }
|
||||
// level result_lvl = sort_level(type_checker(env(), m_lctx).ensure_type(result_type));
|
||||
// expr new_motive = m_lctx.mk_lambda(fvars, result_type);
|
||||
// cases_args[motive_idx] = new_motive;
|
||||
// /* We need to update the resultant universe. */
|
||||
// levels new_cases_lvls = levels(result_lvl, tail(const_levels(cases)));
|
||||
// cases = update_constant(cases, new_cases_lvls);
|
||||
// }
|
||||
// /* Update minor premises */
|
||||
// for (unsigned i = 0; i < nminors; i++) {
|
||||
// unsigned minor_idx = first_minor_idx + i;
|
||||
// expr minor = cases_args[minor_idx];
|
||||
// flet<local_ctx> save_lctx(m_lctx, m_lctx);
|
||||
// buffer<expr> minor_fvars;
|
||||
// unsigned old_fvars_size = m_fvars.size();
|
||||
// while (is_lambda(minor)) {
|
||||
// expr new_d = instantiate_rev(binding_domain(minor), minor_fvars.size(), minor_fvars.data());
|
||||
// expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(minor), new_d, binding_info(minor));
|
||||
// minor_fvars.push_back(new_fvar);
|
||||
// minor = binding_body(minor);
|
||||
// }
|
||||
// expr new_minor = visit(instantiate_rev(minor, minor_fvars.size(), minor_fvars.data()));
|
||||
// for (unsigned i = 0; i < args.size(); i++) {
|
||||
// expr new_minor_type = whnf_infer_type(new_minor);
|
||||
// lean_assert(is_pi(new_minor_type));
|
||||
// new_minor = mk_app(new_minor, mk_cast(arg_types[i], binding_domain(new_minor_type), args[i]));
|
||||
// }
|
||||
// new_minor = visit_let_value(new_minor);
|
||||
// type_checker tc(m_st, m_lctx);
|
||||
// new_minor = mk_cast(tc, tc.infer(new_minor), result_type, new_minor);
|
||||
// new_minor = mk_let(old_fvars_size, new_minor);
|
||||
// new_minor = m_lctx.mk_lambda(minor_fvars, new_minor);
|
||||
// cases_args[minor_idx] = new_minor;
|
||||
// }
|
||||
// return mk_let_decl(mk_app(cases, cases_args));
|
||||
// }
|
||||
|
||||
expr reduce_lc_cast(expr const & e) {
|
||||
buffer<expr> args;
|
||||
expr const & cast_fn1 = get_app_args(e, args);
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import init.lean.parser.parsec
|
|||
import init.control.coroutine
|
||||
universes u v w r s
|
||||
|
||||
set_option trace.compiler.lcnf true
|
||||
set_option trace.compiler.stage1 true
|
||||
-- set_option pp.implicit true
|
||||
set_option pp.binder_types false
|
||||
set_option pp.proofs true
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue