diff --git a/src/library/compiler/llnf.cpp b/src/library/compiler/llnf.cpp index 9aad8e2c32..a95acb61ed 100644 --- a/src/library/compiler/llnf.cpp +++ b/src/library/compiler/llnf.cpp @@ -1456,6 +1456,16 @@ class explicit_rc_fn { collect_live_vars_core(e, visited_jp, r); } + void collect_rhs_live_vars(expr const & e, name_set & r) { + lean_assert(is_app(e) || is_constant(e) || is_lit(e)); + buffer args; + get_app_args(e, args); + for (expr const & arg : args) { + if (is_fvar(arg)) + r.insert(fvar_name(arg)); + } + } + expr get_value_of(expr const & x) const { lean_assert(is_fvar(x)); return *m_lctx.get_local_decl(x).get_value(); @@ -1696,11 +1706,34 @@ class explicit_rc_fn { } } + expr process_cases(expr const & e) { + name_set cases_live_vars; + collect_live_vars(e, cases_live_vars); + buffer args; + expr const & fn = get_app_args(e, args); + for (unsigned i = 1; i < args.size(); i++) { + expr arg = args[i]; /* A "case/branch" of the `cases_on` term. */ + name_set arg_live_vars; + collect_live_vars(arg, arg_live_vars); + unsigned saved_fvars_size = m_fvars.size(); + arg = visit(arg); + arg = mk_let(saved_fvars_size, arg); + /* We must decrement variables that are live at `cases_live_vars`, but are not live at `arg_live_vars`. */ + cases_live_vars.for_each([&](name const & x_name) { + if (!arg_live_vars.contains(x_name)) { + expr x = m_lctx.get_local_decl(x_name).mk_ref(); + arg = ::lean::mk_let(next_name(), mk_void_type(), mk_dec(x), arg); + } + }); + args[i] = arg; + } + return mk_app(fn, args); + } + /* Process a terminal: cases, jmp or variable */ expr process_terminal(expr const & e, buffer & entries) { if (is_cases_on_app(env(), e)) { - // TODO(Leo) - return e; + return process_cases(e); } else if (is_jmp(e)) { add_incs_for_jmp_args(e, entries); return e; @@ -1729,6 +1762,8 @@ class explicit_rc_fn { local_decl x_decl = m_lctx.get_local_decl(x); if (!is_join_point_name(x_decl.get_user_name())) { process(x, entries, live_obj_vars); + collect_rhs_live_vars(*x_decl.get_value(), live_obj_vars); + live_obj_vars.erase(fvar_name(x)); } else { expr jp_val = visit_jp_lambda(*x_decl.get_value()); entries.emplace_back(x, jp_val); @@ -1916,7 +1951,7 @@ pair to_llnf(environment const & env, comp_decls const expr new_code = explicit_boxing_fn(new_env)(r.fst(), r.snd()); new_code = ecse(new_env, new_code); new_code = elim_dead_let(new_code); - // new_code = explicit_rc_fn(new_env)(r.fst(), new_code); + new_code = explicit_rc_fn(new_env)(r.fst(), new_code); r = comp_decl(r.fst(), new_code); } }