feat(library/compiler/csimp): preserve joint points

This commit is contained in:
Leonardo de Moura 2018-09-26 17:30:38 -07:00
parent ae81ac2768
commit ccd73701cb

View file

@ -61,6 +61,14 @@ class csimp_fn {
return m_simplified.find(e) != m_simplified.end();
}
bool is_join_point_app(expr const & e) const {
if (!is_app(e)) return false;
expr const & fn = get_app_fn(e);
return
is_fvar(fn) &&
is_join_point_name(m_lctx.get_local_decl(fn).get_user_name());
}
/* Very simple predicate used to decide whether we should inline joint-points or not.
TODO(Leo): improve */
bool is_small(expr const & e) const {
@ -191,14 +199,14 @@ class csimp_fn {
return m_lctx.mk_lambda(xs, e);
}
expr get_minor_body(expr minor, buffer<expr> & xs) {
while (is_lambda(minor)) {
expr d = instantiate_rev(binding_domain(minor), xs.size(), xs.data());
expr x = m_lctx.mk_local_decl(ngen(), binding_name(minor), d, binding_info(minor));
expr get_lambda_body(expr e, buffer<expr> & xs) {
while (is_lambda(e)) {
expr d = instantiate_rev(binding_domain(e), xs.size(), xs.data());
expr x = m_lctx.mk_local_decl(ngen(), binding_name(e), d, binding_info(e));
xs.push_back(x);
minor = binding_body(minor);
e = binding_body(e);
}
return instantiate_rev(minor, xs.size(), xs.data());
return instantiate_rev(e, xs.size(), xs.data());
}
/* Move let-decl `fvar` to the minor premise at position `minor_idx` of cases_on-application `c`. */
@ -209,7 +217,7 @@ class csimp_fn {
expr minor = args[minor_idx];
flet<local_ctx> save_lctx(m_lctx, m_lctx);
buffer<expr> xs;
minor = get_minor_body(minor, xs);
minor = get_lambda_body(minor, xs);
if (minor == fvar) {
/* `let x := v in x` ==> `v` */
minor = *m_lctx.get_local_decl(fvar).get_value();
@ -233,7 +241,7 @@ class csimp_fn {
The term `e` is "copied" in each branch of the the `cases_on` expression `c`.
This method creates one (or more) join-point(s) for `e` (if needed).
This method main fail because of dependent types. */
optional<expr> mk_join_point_float_cases_on(expr const & fvar, expr const & c, expr const & e) {
optional<expr> mk_join_point_float_cases_on(expr const & fvar, expr const & e, expr const & c, buffer<expr> & new_jps) {
lean_assert(is_cases_on_app(env(), c));
expr const & c_fn = get_app_fn(c);
inductive_val I_val = env().get(const_name(c_fn).get_prefix()).to_inductive_val();
@ -254,17 +262,17 @@ class csimp_fn {
for (unsigned minor_idx = begin_minors; minor_idx < end_minors; minor_idx++) {
expr minor = args[minor_idx];
if (get_lcnf_size(env(), minor) > m_cfg.m_float_cases_jp_branch_threshold) {
buffer<bool> used_xs; /* used_xs[i] iff `minor` uses `xs[i]` */
buffer<bool> used_zs; /* used_zs[i] iff `minor` uses `zs[i]` */
bool used_fvar = false; /* true iff `minor` uses `fvar` */
bool used_unit = false; /* true if we needed to add `unit ->` to joint point */
expr jp_val;
/* Create join-point value: `jp-val` */
{
flet<local_ctx> save_lctx(m_lctx, m_lctx);
buffer<expr> xs;
minor = get_minor_body(minor, xs);
mark_used_fvars(minor, xs, used_xs);
lean_assert(xs.size() == used_xs.size());
buffer<expr> zs;
minor = get_lambda_body(minor, zs);
mark_used_fvars(minor, zs, used_zs);
lean_assert(zs.size() == used_zs.size());
used_fvar = false;
jp_val = minor;
buffer<expr> jp_args;
@ -282,9 +290,9 @@ class csimp_fn {
return none_expr();
}
}
for (unsigned i = 0; i < used_xs.size(); i++) {
if (used_xs[i])
jp_args.push_back(xs[i]);
for (unsigned i = 0; i < used_zs.size(); i++) {
if (used_zs[i])
jp_args.push_back(zs[i]);
}
if (jp_args.empty()) {
jp_args.push_back(m_lctx.mk_local_decl(ngen(), "_", mk_unit()));
@ -296,24 +304,24 @@ class csimp_fn {
expr jp_type = cheap_beta_reduce(infer_type(jp_val));
mark_simplified(jp_val);
expr jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_type, jp_val);
m_fvars.push_back(jp_var);
new_jps.push_back(jp_var);
/* Replace minor with new jp */
{
flet<local_ctx> save_lctx(m_lctx, m_lctx);
buffer<expr> xs;
buffer<expr> zs;
minor = args[minor_idx];
minor = get_minor_body(minor, xs);
lean_assert(xs.size() == used_xs.size());
minor = get_lambda_body(minor, zs);
lean_assert(zs.size() == used_zs.size());
expr new_minor = jp_var;
if (used_unit)
new_minor = mk_app(new_minor, mk_unit_mk());
if (used_fvar)
new_minor = mk_app(new_minor, fvar);
for (unsigned i = 0; i < used_xs.size(); i++) {
if (used_xs[i])
new_minor = mk_app(new_minor, xs[i]);
for (unsigned i = 0; i < used_zs.size(); i++) {
if (used_zs[i])
new_minor = mk_app(new_minor, zs[i]);
}
new_minor = mk_minor_lambda(xs, new_minor);
new_minor = mk_minor_lambda(zs, new_minor);
args[minor_idx] = new_minor;
modified = true;
}
@ -350,7 +358,7 @@ class csimp_fn {
expr jp_type = cheap_beta_reduce(infer_type(jp_val));
mark_simplified(jp_val);
expr jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_type, jp_val);
m_fvars.push_back(jp_var);
new_jps.push_back(jp_var);
lean_trace(name({"compiler", "simp"}),
tout() << "mk_join " << fvar << "\n" << c << "\n---\n"
<< e << "\n======>\n" << mk_app(jp_var, fvar) << "\n";);
@ -358,24 +366,104 @@ class csimp_fn {
}
}
/* Float cases transformation
/* Given `e[x]`, create a let-decl `y := v`, and return `e[y]`
Casts are introduced if necessary. The result is `none` if it fails to produce type correct `e[y]`. */
optional<expr> apply_at(expr const & x, expr const & e, expr const & v) {
local_decl x_decl = m_lctx.get_local_decl(x);
expr v_type = infer_type(v);
expr new_v = mk_cast(v_type, x_decl.get_type(), v);
expr y = m_lctx.mk_local_decl(ngen(), x_decl.get_user_name(), x_decl.get_type(), new_v);
optional<expr> e_y_opt = replace_fvar_with(m_st, m_lctx, e, x, y);
if (!e_y_opt) return none_expr(); /* Failed to produce type correct `e[y]` */
expr e_y = *e_y_opt;
m_fvars.push_back(y);
return some_expr(visit(e_y, false));
}
/*
Given `e[x]`
```
let x := cases_on m
let jp := fun z, let .... in e'
```
==>
```
let jp' := fun z, let ... y := e' in e[y]
```
If `e'` is a `cases_on` application, we use `float_cases_on_core`. That is,
```
let jp := fun z, let ... in
cases_on m
(fun y_1, let ... in e_1)
...
(fun y_n, let ... in e_n)
in e
```
==>
```
let jp := fun z, let ... in
cases_on m
(fun y_1, let ... y := e_1 in e[y])
...
(fun y_n, let ... y := e_n in e[y])
```
Remark: this method return `none` if the new join point cannot be created
due to type errors. */
optional<expr> mk_new_join_point(expr const & x, expr const & e, expr const & jp, buffer<expr> & new_jps, expr_map<expr> & new_jp_cache) {
auto it = new_jp_cache.find(jp);
if (it != new_jp_cache.end())
return some_expr(it->second);
local_decl jp_decl = m_lctx.get_local_decl(jp);
lean_assert(is_join_point_name(jp_decl.get_user_name()));
expr jp_val = *jp_decl.get_value();
buffer<expr> zs;
unsigned saved_fvars_size = m_fvars.size();
jp_val = visit(get_lambda_body(jp_val, zs), false);
expr new_jp_val;
expr e_y;
if (is_join_point_app(jp_val)) {
buffer<expr> jp2_args;
expr const & jp2 = get_app_args(jp_val, jp2_args);
optional<expr> new_jp2_opt = mk_new_join_point(x, e, jp2, new_jps, new_jp_cache);
if (!new_jp2_opt) return none_expr();
e_y = mk_app(*new_jp2_opt, jp2_args);
} else if (is_cases_on_app(env(), jp_val)) {
optional<expr> e_y_opt = float_cases_on_core(x, e, jp_val, new_jps, new_jp_cache);
if (!e_y_opt) return none_expr();
e_y = *e_y_opt;
} else {
optional<expr> e_y_opt = apply_at(x, e, jp_val);
if (!e_y_opt) return none_expr();
e_y = *e_y_opt;
}
expr e_y_type = infer_type(e_y);
expr jp_val_type = infer_type(jp_val);
new_jp_val = mk_cast(e_y_type, jp_val_type, e_y);
new_jp_val = mk_let(saved_fvars_size, new_jp_val);
new_jp_val = m_lctx.mk_lambda(zs, new_jp_val);
mark_simplified(new_jp_val);
expr new_jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_decl.get_type(), new_jp_val);
new_jps.push_back(new_jp_var);
new_jp_cache.insert(mk_pair(jp, new_jp_var));
return some_expr(new_jp_var);
}
/* Given `e[x]`
```
cases_on m
(fun zs, let ... in e_1)
...
(fun zs, let ... in e_n)
```
==>
```
cases_on m
(fun y_1, let ... x := e_1 in e)
(fun zs, let ... y := e_1 in e[y])
...
(fun y_n, let ... x := e_n in e)
(fun y_n, let ... y := e_n in e[y])
``` */
optional<expr> float_cases_on_core(expr const & fvar, expr const & c, expr e) {
optional<expr> float_cases_on_core(expr const & x, expr const & e, expr const & c, buffer<expr> & new_jps, expr_map<expr> & new_jp_cache) {
lean_assert(is_cases_on_app(env(), c));
local_decl fvar_decl = m_lctx.get_local_decl(fvar);
local_decl x_decl = m_lctx.get_local_decl(x);
expr result_type = whnf_infer_type(e);
buffer<expr> c_args;
expr c_fn = get_app_args(c, c_args);
@ -389,17 +477,17 @@ class csimp_fn {
/* Update motive */
{
flet<local_ctx> save_lctx(m_lctx, m_lctx);
buffer<expr> fvars;
buffer<expr> zs;
expr motive = c_args[motive_idx];
expr motive_type = whnf_infer_type(motive);
for (unsigned i = 0; i < nindices + 1; i++) {
lean_assert(is_pi(motive_type));
expr fvar = m_lctx.mk_local_decl(ngen(), binding_name(motive_type), binding_domain(motive_type), binding_info(motive_type));
fvars.push_back(fvar);
motive_type = whnf(instantiate(binding_body(motive_type), fvar));
expr z = m_lctx.mk_local_decl(ngen(), binding_name(motive_type), binding_domain(motive_type), binding_info(motive_type));
zs.push_back(z);
motive_type = whnf(instantiate(binding_body(motive_type), z));
}
level result_lvl = sort_level(tc().ensure_type(result_type));
expr new_motive = m_lctx.mk_lambda(fvars, result_type);
expr new_motive = m_lctx.mk_lambda(zs, result_type);
c_args[motive_idx] = new_motive;
/* We need to update the resultant universe. */
levels new_cases_lvls = levels(result_lvl, tail(const_levels(c_fn)));
@ -407,26 +495,29 @@ class csimp_fn {
}
/* Update minor premises */
for (unsigned i = 0; i < nminors; i++) {
unsigned minor_idx = first_minor_idx + i;
expr minor = c_args[minor_idx];
buffer<expr> xs;
flet<local_ctx> save_lctx(m_lctx, m_lctx);
unsigned old_fvars_size = m_fvars.size();
expr minor_val = visit(get_minor_body(minor, xs), false);
/* TODO(Leo): We need to preserve join points. */
expr minor_val_type = infer_type(minor_val);
minor_val = mk_cast(minor_val_type, fvar_decl.get_type(), minor_val);
expr new_fvar = m_lctx.mk_local_decl(ngen(), fvar_decl.get_user_name(), fvar_decl.get_type(), minor_val);
optional<expr> new_minor_opt = replace_fvar_with(m_st, m_lctx, e, fvar, new_fvar);
if (!new_minor_opt) return none_expr(); /* Failed to produce type correct `new_minor` */
expr new_minor = *new_minor_opt;
m_fvars.push_back(new_fvar);
new_minor = visit(new_minor, false);
expr new_minor_type = infer_type(new_minor);
new_minor = mk_cast(new_minor_type, result_type, new_minor);
new_minor = mk_let(old_fvars_size, new_minor);
new_minor = mk_minor_lambda(xs, new_minor);
c_args[minor_idx] = new_minor;
unsigned minor_idx = first_minor_idx + i;
expr minor = c_args[minor_idx];
buffer<expr> zs;
unsigned saved_fvars_size = m_fvars.size();
expr minor_val = visit(get_lambda_body(minor, zs), false);
expr new_minor;
if (is_join_point_app(minor_val)) {
buffer<expr> jp_args;
expr const & jp = get_app_args(minor_val, jp_args);
optional<expr> new_jp_opt = mk_new_join_point(x, e, jp, new_jps, new_jp_cache);
if (!new_jp_opt) return none_expr();
expr new_jp = *new_jp_opt;
new_minor = mk_app(new_jp, jp_args);
} else {
optional<expr> e_y_opt = apply_at(x, e, minor_val);
if (!e_y_opt) return none_expr();
expr e_y = *e_y_opt;
expr e_y_type = infer_type(e_y);
new_minor = mk_cast(e_y_type, result_type, e_y);
}
new_minor = mk_let(saved_fvars_size, new_minor);
new_minor = mk_minor_lambda(zs, new_minor);
c_args[minor_idx] = new_minor;
}
lean_trace(name({"compiler", "simp"}),
tout() << "float_cases_on [" << get_lcnf_size(env(), e) << "]\n" << c << "\n----\n" << e << "\n=====>\n"
@ -436,11 +527,12 @@ class csimp_fn {
/* Float cases transformation (see: `float_cases_on_core`).
This version may create join points if `e` is big, or "good" join-points could not be created. */
optional<expr> float_cases_on(expr const & fvar, expr const & c, expr const & e) {
optional<expr> float_cases_on(expr const & x, expr const & e, expr const & c, buffer<expr> & new_jps) {
expr_map<expr> new_jp_cache;
unsigned saved_fvars_size = m_fvars.size();
local_ctx saved_lctx = m_lctx;
if (optional<expr> new_e = mk_join_point_float_cases_on(fvar, c, e)) {
if (optional<expr> r = float_cases_on_core(fvar, c, *new_e)) {
if (optional<expr> new_e = mk_join_point_float_cases_on(x, e, c, new_jps)) {
if (optional<expr> r = float_cases_on_core(x, *new_e, c, new_jps, new_jp_cache)) {
return r;
}
}
@ -591,23 +683,23 @@ class csimp_fn {
if (is_cases_on_app(env(), val)) {
/* Float cases transformation. */
if (m_cfg.m_float_cases) {
unsigned m_fvars_init_size = m_fvars.size();
/* We first create a let-declaration with all entries that depends on the current
`fvar` which is a cases_on application. */
buffer<pair<expr, expr>> entries_dep_curr;
buffer<pair<expr, expr>> entries_ndep_curr;
split_entries(entries, fvar, entries_dep_curr, entries_ndep_curr);
expr new_e = mk_let(entries_dep_curr, e);
if (optional<expr> new_e_opt = float_cases_on(fvar, val, new_e)) {
buffer<expr> new_jps;
if (optional<expr> new_e_opt = float_cases_on(fvar, new_e, val, new_jps)) {
e = *new_e_opt;
/* Reset `e_fvars` and `entries_fvars`, we need to reconstruct them. */
e_fvars.clear(); entries_fvars.clear();
collect_used(e, e_fvars);
/* Join points may have been generated, we move them to entries. */
entries.clear();
while (m_fvars.size() > m_fvars_init_size) {
expr jp_fvar = m_fvars.back();
m_fvars.pop_back();
while (!new_jps.empty()) {
expr jp_fvar = new_jps.back();
new_jps.pop_back();
local_decl jp_decl = m_lctx.get_local_decl(jp_fvar);
expr jp_type = jp_decl.get_type();
expr jp_val = *jp_decl.get_value();
@ -662,8 +754,12 @@ class csimp_fn {
let_fvars.push_back(new_val);
} else {
name n = let_name(e);
if (is_internal_name(n) && !is_join_point_name(n))
n = next_name();
if (is_internal_name(n)) {
if (is_join_point_name(n))
n = next_jp_name();
else
n = next_name();
}
expr new_fvar = m_lctx.mk_local_decl(ngen(), n, new_type, new_val);
let_fvars.push_back(new_fvar);
m_fvars.push_back(new_fvar);
@ -801,11 +897,11 @@ class csimp_fn {
expr minor = args[minor_idx];
unsigned saved_fvars_size = m_fvars.size();
flet<local_ctx> save_lctx(m_lctx, m_lctx);
buffer<expr> xs;
minor = get_minor_body(minor, xs);
buffer<expr> zs;
minor = get_lambda_body(minor, zs);
expr new_minor = visit(minor, false);
new_minor = mk_let(saved_fvars_size, new_minor);
new_minor = mk_minor_lambda(xs, new_minor);
new_minor = mk_minor_lambda(zs, new_minor);
args[minor_idx] = new_minor;
}
expr r = mk_app(c, args);
@ -943,9 +1039,12 @@ class csimp_fn {
return beta_reduce(fn, e, is_let_val);
} else if (is_cases_on_app(env(), fn) && m_cfg.m_float_cases_app) {
lean_assert(is_fvar(get_app_fn(e)));
buffer<expr> new_jps;
/* float cases_on from application */
if (optional<expr> new_e = float_cases_on_core(get_app_fn(e), fn, e)) {
expr_map<expr> new_jp_cache;
if (optional<expr> new_e = float_cases_on_core(get_app_fn(e), e, fn, new_jps, new_jp_cache)) {
mark_simplified(*new_e);
m_fvars.append(new_jps);
return *new_e;
} else {
mark_simplified(e);