feat(library/compiler/llnf): add process_cases

This commit is contained in:
Leonardo de Moura 2019-01-17 13:29:38 -08:00
parent c1534fd476
commit 244079a4cc

View file

@ -1456,6 +1456,16 @@ class explicit_rc_fn {
collect_live_vars_core(e, visited_jp, r);
}
void collect_rhs_live_vars(expr const & e, name_set & r) {
lean_assert(is_app(e) || is_constant(e) || is_lit(e));
buffer<expr> args;
get_app_args(e, args);
for (expr const & arg : args) {
if (is_fvar(arg))
r.insert(fvar_name(arg));
}
}
expr get_value_of(expr const & x) const {
lean_assert(is_fvar(x));
return *m_lctx.get_local_decl(x).get_value();
@ -1696,11 +1706,34 @@ class explicit_rc_fn {
}
}
expr process_cases(expr const & e) {
name_set cases_live_vars;
collect_live_vars(e, cases_live_vars);
buffer<expr> args;
expr const & fn = get_app_args(e, args);
for (unsigned i = 1; i < args.size(); i++) {
expr arg = args[i]; /* A "case/branch" of the `cases_on` term. */
name_set arg_live_vars;
collect_live_vars(arg, arg_live_vars);
unsigned saved_fvars_size = m_fvars.size();
arg = visit(arg);
arg = mk_let(saved_fvars_size, arg);
/* We must decrement variables that are live at `cases_live_vars`, but are not live at `arg_live_vars`. */
cases_live_vars.for_each([&](name const & x_name) {
if (!arg_live_vars.contains(x_name)) {
expr x = m_lctx.get_local_decl(x_name).mk_ref();
arg = ::lean::mk_let(next_name(), mk_void_type(), mk_dec(x), arg);
}
});
args[i] = arg;
}
return mk_app(fn, args);
}
/* Process a terminal: cases, jmp or variable */
expr process_terminal(expr const & e, buffer<expr_pair> & entries) {
if (is_cases_on_app(env(), e)) {
// TODO(Leo)
return e;
return process_cases(e);
} else if (is_jmp(e)) {
add_incs_for_jmp_args(e, entries);
return e;
@ -1729,6 +1762,8 @@ class explicit_rc_fn {
local_decl x_decl = m_lctx.get_local_decl(x);
if (!is_join_point_name(x_decl.get_user_name())) {
process(x, entries, live_obj_vars);
collect_rhs_live_vars(*x_decl.get_value(), live_obj_vars);
live_obj_vars.erase(fvar_name(x));
} else {
expr jp_val = visit_jp_lambda(*x_decl.get_value());
entries.emplace_back(x, jp_val);
@ -1916,7 +1951,7 @@ pair<environment, comp_decls> to_llnf(environment const & env, comp_decls const
expr new_code = explicit_boxing_fn(new_env)(r.fst(), r.snd());
new_code = ecse(new_env, new_code);
new_code = elim_dead_let(new_code);
// new_code = explicit_rc_fn(new_env)(r.fst(), new_code);
new_code = explicit_rc_fn(new_env)(r.fst(), new_code);
r = comp_decl(r.fst(), new_code);
}
}