From 54a89dabb72f7fc4563c957e5f4e76c41f98bc9a Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 20 Feb 2019 16:12:58 -0800 Subject: [PATCH] feat(library/compiler/llnf): new reset/reuse insertion procedure --- src/library/compiler/llnf.cpp | 401 ++++++++++++++++++++++++++-------- src/library/compiler/llnf.h | 6 +- 2 files changed, 320 insertions(+), 87 deletions(-) diff --git a/src/library/compiler/llnf.cpp b/src/library/compiler/llnf.cpp index 26c53752d5..8624235e86 100644 --- a/src/library/compiler/llnf.cpp +++ b/src/library/compiler/llnf.cpp @@ -107,7 +107,7 @@ scalar fields. expr mk_llnf_cnstr(name const & I, unsigned cidx, unsigned num_usizes, unsigned num_bytes) { return mk_constant(name(name(name(name(I, g_cnstr), cidx), num_usizes), num_bytes)); } -bool is_llnf_cnstr_core(expr const & e, unsigned & cidx, unsigned & num_usizes, unsigned & num_bytes) { +bool is_llnf_cnstr(expr const & e, name & I, unsigned & cidx, unsigned & num_usizes, unsigned & num_bytes) { if (!is_constant(e)) return false; name const & n3 = const_name(e); if (!is_internal_name(n3)) return false; @@ -120,10 +120,9 @@ bool is_llnf_cnstr_core(expr const & e, unsigned & cidx, unsigned & num_usizes, if (n1.is_atomic() || !n1.is_numeral()) return false; cidx = n1.get_numeral().get_small_value(); name const & n0 = n1.get_prefix(); - return !n0.is_atomic() && n0.is_string() && n0.get_string() == g_cnstr; -} -bool is_llnf_cnstr(expr const & e, unsigned & cidx, unsigned & num_usizes, unsigned & num_bytes) { - return is_llnf_cnstr_core(e, cidx, num_usizes, num_bytes); + if (n0.is_atomic() || !n0.is_string() || n0.get_string() != g_cnstr) return false; + I = n0.get_prefix(); + return true; } /* The `_reuse....` is similar to `_cnstr...`, but it takes an extra argument: a memory cell that may be reused. */ @@ -286,6 +285,82 @@ static void get_borrowed_info(environment const & env, name const & n, buffer is_enum_type(environment const & env, expr const & type) { + expr const & I = get_app_fn(type); + if (!is_constant(I)) return optional(); + return is_enum_type(env, const_name(I)); +} + +static void get_cnstr_info_core(type_checker::state & st, bool unboxed, name const & n, buffer & result) { + environment const & env = st.env(); + constant_info info = env.get(n); + lean_assert(info.is_constructor()); + constructor_val val = info.to_constructor_val(); + expr type = info.get_type(); + name I_name = val.get_induct(); + unsigned nparams = val.get_nparams(); + local_ctx lctx; + buffer telescope; + unsigned next_object = 0; + unsigned next_usize = 0; + unsigned next_offset = 0; + to_telescope(env, lctx, st.ngen(), type, telescope); + lean_assert(telescope.size() >= nparams); + for (unsigned i = nparams; i < telescope.size(); i++) { + expr ftype = lctx.get_type(telescope[i]); + if (is_irrelevant_type(st, lctx, ftype)) { + result.push_back(field_info::mk_irrelevant()); + } else if (unboxed) { + type_checker tc(st, lctx); + ftype = tc.whnf(ftype); + if (is_usize_type(ftype)) { + result.push_back(field_info::mk_usize(next_usize)); + next_usize++; + } else if (optional sz = is_builtin_scalar(ftype)) { + result.push_back(field_info::mk_scalar(*sz, next_offset, ftype)); + next_offset += *sz; + } else if (optional sz = is_enum_type(env, ftype)) { + optional uint = to_uint_type(*sz); + if (!uint) throw exception("code generation failed, enumeration type is too big"); + result.push_back(field_info::mk_scalar(*sz, next_offset, *uint)); + next_offset += *sz; + } else { + result.push_back(field_info::mk_object(next_object)); + next_object++; + } + } else { + result.push_back(field_info::mk_object(next_object)); + next_object++; + } + } + unsigned nobjs = next_object; + unsigned nusizes = next_usize; + if (unboxed) { + /* Remark: + - usize fields are stored after object fields. + - regular scalar fields are stored after object and usize fields */ + for (field_info & info : result) { + switch (info.m_kind) { + case field_info::Scalar: + info.m_offset += (nobjs + nusizes) * sizeof(void*); + break; + case field_info::USize: + info.m_offset += nobjs * sizeof(void*); + break; + default: + break; + } + } + } +} + +static cnstr_info get_cnstr_info(type_checker::state & st, bool unboxed, name const & n) { + buffer finfos; + get_cnstr_info_core(st, unboxed, n, finfos); + unsigned cidx = get_constructor_idx(st.env(), n); + return cnstr_info(cidx, to_list(finfos)); +} + class to_llnf_fn { typedef name_hash_set name_set; typedef name_hash_map cnstr_info_cache; @@ -299,83 +374,13 @@ class to_llnf_fn { unsigned m_next_idx{1}; unsigned m_next_jp_idx{1}; cnstr_info_cache m_cnstr_info_cache; - enum_cache m_enum_cache; environment const & env() const { return m_st.env(); } name_generator & ngen() { return m_st.ngen(); } optional is_enum_type(expr const & type) { - expr const & I = get_app_fn(type); - if (!is_constant(I)) return optional(); - auto it = m_enum_cache.find(const_name(I)); - if (it != m_enum_cache.end()) - return it->second; - optional r = ::lean::is_enum_type(env(), const_name(I)); - m_enum_cache.insert(mk_pair(const_name(I), r)); - return r; - } - - void get_cnstr_info_core(name const & n, buffer & result) { - constant_info info = env().get(n); - lean_assert(info.is_constructor()); - constructor_val val = info.to_constructor_val(); - expr type = info.get_type(); - name I_name = val.get_induct(); - unsigned nparams = val.get_nparams(); - local_ctx lctx; - buffer telescope; - unsigned next_object = 0; - unsigned next_usize = 0; - unsigned next_offset = 0; - to_telescope(env(), lctx, ngen(), type, telescope); - lean_assert(telescope.size() >= nparams); - for (unsigned i = nparams; i < telescope.size(); i++) { - expr ftype = lctx.get_type(telescope[i]); - if (is_irrelevant_type(m_st, lctx, ftype)) { - result.push_back(field_info::mk_irrelevant()); - } else if (m_unboxed) { - type_checker tc(m_st, lctx); - ftype = tc.whnf(ftype); - if (is_usize_type(ftype)) { - result.push_back(field_info::mk_usize(next_usize)); - next_usize++; - } else if (optional sz = is_builtin_scalar(ftype)) { - result.push_back(field_info::mk_scalar(*sz, next_offset, ftype)); - next_offset += *sz; - } else if (optional sz = is_enum_type(ftype)) { - optional uint = to_uint_type(*sz); - if (!uint) throw exception("code generation failed, enumeration type is too big"); - result.push_back(field_info::mk_scalar(*sz, next_offset, *uint)); - next_offset += *sz; - } else { - result.push_back(field_info::mk_object(next_object)); - next_object++; - } - } else { - result.push_back(field_info::mk_object(next_object)); - next_object++; - } - } - unsigned nobjs = next_object; - unsigned nusizes = next_usize; - if (m_unboxed) { - /* Remark: - - usize fields are stored after object fields. - - regular scalar fields are stored after object and usize fields */ - for (field_info & info : result) { - switch (info.m_kind) { - case field_info::Scalar: - info.m_offset += (nobjs + nusizes) * sizeof(void*); - break; - case field_info::USize: - info.m_offset += nobjs * sizeof(void*); - break; - default: - break; - } - } - } + return ::lean::is_enum_type(env(), type); } unsigned get_arity(name const & n) const { @@ -419,10 +424,7 @@ class to_llnf_fn { auto it = m_cnstr_info_cache.find(n); if (it != m_cnstr_info_cache.end()) return it->second; - buffer finfos; - get_cnstr_info_core(n, finfos); - unsigned cidx = get_constructor_idx(env(), n); - cnstr_info r(cidx, to_list(finfos)); + cnstr_info r = ::lean::get_cnstr_info(m_st, m_unboxed, n); m_cnstr_info_cache.insert(mk_pair(n, r)); return r; } @@ -646,7 +648,7 @@ class to_llnf_fn { } j++; } - expr r = mk_app(mk_llnf_cnstr(cidx, k_info.m_num_usizes, k_info.m_scalar_sz), obj_args); + expr r = mk_app(mk_llnf_cnstr(I, cidx, k_info.m_num_usizes, k_info.m_scalar_sz), obj_args); j = nparams; unsigned offset = 0; unsigned uidx = 0; @@ -773,6 +775,7 @@ public: } }; +/* Push projections inside `cases_on` branches. */ class push_proj_fn { environment m_env; name_generator m_ngen; @@ -872,7 +875,7 @@ class push_proj_fn { flet save_lctx(m_lctx, m_lctx); buffer fvars; while (is_let(e)) { - expr val = instantiate_rev(let_value(e), fvars.size(), fvars.data()); + expr val = visit(instantiate_rev(let_value(e), fvars.size(), fvars.data())); expr new_fvar = m_lctx.mk_local_decl(m_ngen, let_name(e), let_type(e), val); fvars.push_back(new_fvar); e = let_body(e); @@ -900,11 +903,237 @@ public: }; class insert_reset_reuse_fn { + type_checker::state m_st; + bool m_unboxed; + local_ctx m_lctx; + name m_r{"_r"}; + unsigned m_next_idx{1}; + + struct opt_ctx { + name I; + cnstr_info cinfo; + expr x; + }; + + struct replace_ctx { + name I; + cnstr_info cinfo; + expr x; + expr reset_x; + replace_ctx(opt_ctx const & octx, expr const & rx): + I(octx.I), cinfo(octx.cinfo), x(octx.x), reset_x(rx) {} + }; + + environment const & env() { return m_st.env(); } + name_generator & ngen() { return m_st.ngen(); } + + expr replace_cnstr(replace_ctx const & rctx, expr const & e) { + buffer args; + expr const & c = get_app_args(e, args); + name I; unsigned cidx; unsigned nusizes; unsigned ssz; + lean_verify(is_llnf_cnstr(c, I, cidx, nusizes, ssz)); + if (I != rctx.I) { + /* Heuristic: we don't want to reuse cells from different types even when they are compatible + because it produces counterintuitive behavior. Here is an example: + ``` + @list.cases_on a + (@prod.cases_on a_1 (λ fst snd, (punit.star, snd))) + (λ a_hd a_tl, + @prod.cases_on a_1 + (λ fst snd, + let _x_1 := nat.add snd a_hd, + _x_2 := (punit.star, _x_1) + in list.mmap'._main._at.accum._spec_1 a_tl _x_2)) + ``` + Without this heuristic, we will try to construct `(punit.star, _x_1)` re-using `a` instead of `a_1`. */ + return e; + } + if (args.size() != rctx.cinfo.m_num_objs || + nusizes != rctx.cinfo.m_num_usizes || + ssz != rctx.cinfo.m_scalar_sz) { + /* This constructor is not compatible with major premise */ + return e; + } + expr r = mk_app(mk_llnf_reuse(cidx, nusizes, ssz, cidx != rctx.cinfo.m_cidx), rctx.reset_x); + return mk_app(r, args); + } + + expr replace_app(replace_ctx const & rctx, expr const & e) { + if (is_llnf_cnstr(get_app_fn(e))) { + return replace_cnstr(rctx, e); + } else if (is_cases_on_app(env(), e)) { + lean_assert(!m_replaced); + buffer args; + expr const & fn = get_app_args(e, args); + bool modified = false; + for (unsigned i = 1; i < args.size(); i++) { + expr new_arg = replace(rctx, args[i]); + if (new_arg != args[i]) { + modified = true; + args[i] = new_arg; + } + } + return modified ? mk_app(fn, args) : e; + } else { + return e; + } + } + + expr replace_let(replace_ctx const & rctx, expr const & e) { + expr new_value = replace(rctx, let_value(e)); + if (new_value != let_value(e)) { + return update_let(e, let_type(e), new_value, let_body(e)); + } else { + expr new_body = replace(rctx, let_body(e)); + return update_let(e, let_type(e), new_value, new_body); + } + } + + expr replace_lambda(replace_ctx const & rctx, expr const & e) { + expr new_body = replace(rctx, binding_body(e)); + return update_binding(e, binding_domain(e), new_body); + } + + expr replace(replace_ctx const & rctx, expr const & e) { + switch (e.kind()) { + case expr_kind::App: return replace_app(rctx, e); + case expr_kind::Let: return replace_let(rctx, e); + case expr_kind::Lambda: return replace_lambda(rctx, e); + default: return e; + } + } + + name next_reset_name() { + name r(m_r, m_next_idx); + m_next_idx++; + return r; + } + + expr replace(opt_ctx const & octx, expr const & e) { + expr reset_x = mk_fvar(ngen().next()); + replace_ctx rctx(octx, reset_x); + expr new_e = replace(rctx, e); + if (e == new_e) return e; + expr reset = mk_app(mk_llnf_reset(octx.cinfo.m_num_objs), octx.x); + return ::lean::mk_let(next_reset_name(), mk_enf_object_type(), reset, abstract(new_e, reset_x)); + } + + expr opt_let(opt_ctx const & octx, expr e) { + lean_assert(is_let(e)); + lean_assert(has_fvar(e, octx.x)); + flet save_lctx(m_lctx, m_lctx); + buffer fvars; + while (is_let(e)) { + expr val = instantiate_rev(let_value(e), fvars.size(), fvars.data()); + expr new_fvar = m_lctx.mk_local_decl(ngen(), let_name(e), let_type(e), val); + fvars.push_back(new_fvar); + if (has_fvar(let_value(e), octx.x) && !has_fvar(let_body(e), octx.x)) { + expr new_body = instantiate_rev(let_body(e), fvars.size(), fvars.data()); + new_body = replace(octx, new_body); + return m_lctx.mk_lambda(fvars, new_body); + } + e = let_body(e); + } + lean_assert(has_fvar(e, octx.x)); + expr new_body = opt(octx, instantiate_rev(e, fvars.size(), fvars.data())); + return m_lctx.mk_lambda(fvars, new_body); + } + + expr opt_cases(opt_ctx const & octx, expr const & e) { + buffer args; + expr const & fn = get_app_args(e, args); + bool modified = false; + for (unsigned i = 1; i < args.size(); i++) { + expr arg = args[i]; + expr new_arg = opt(octx, arg); + if (arg != new_arg) { + modified = true; + args[i] = new_arg; + } + } + return modified ? mk_app(fn, args) : e; + } + + expr opt(opt_ctx const & octx, expr const & e) { + if (!has_fvar(e, octx.x)) { + return replace(octx, e); + } else if (is_let(e)) { + return opt_let(octx, e); + } else if (is_cases_on_app(env(), e)) { + return opt_cases(octx, e); + } else { + return e; + } + } + + expr optimize_cases_on(expr const & e) { + lean_assert(is_cases_on_app(env(), e)); + buffer cases_args; + expr const & cases_fn = get_app_args(e, cases_args); + name const & I_name = const_name(cases_fn).get_prefix(); + expr const & major = cases_args[0]; + buffer cnames; + get_constructor_names(env(), I_name, cnames); + lean_assert(cases_args.size() == cnames.size() + 1); + for (unsigned i = 1; i < cases_args.size(); i++) { + cnstr_info cinfo = get_cnstr_info(m_st, m_unboxed, cnames[i-1]); + expr minor = optimize(cases_args[i]); + minor = opt(opt_ctx{I_name, cinfo, major}, minor); + cases_args[i] = minor; + } + return mk_app(cases_fn, cases_args); + } + + expr optimize_app(expr const & e) { + if (is_cases_on_app(env(), e)) + return optimize_cases_on(e); + else + return e; + } + + expr optimize_lambda(expr e) { + lean_assert(is_lambda(e)); + flet save_lctx(m_lctx, m_lctx); + buffer fvars; + while (is_lambda(e)) { + /* Types are ignored in compilation steps. So, we do not invoke visit for d. */ + expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(e), binding_domain(e)); + fvars.push_back(new_fvar); + e = binding_body(e); + } + expr new_body = optimize(instantiate_rev(e, fvars.size(), fvars.data())); + return m_lctx.mk_lambda(fvars, new_body); + } + + expr optimize_let(expr e) { + lean_assert(is_let(e)); + flet save_lctx(m_lctx, m_lctx); + buffer fvars; + while (is_let(e)) { + expr val = optimize(instantiate_rev(let_value(e), fvars.size(), fvars.data())); + expr new_fvar = m_lctx.mk_local_decl(ngen(), let_name(e), let_type(e), val); + fvars.push_back(new_fvar); + e = let_body(e); + } + expr new_body = optimize(instantiate_rev(e, fvars.size(), fvars.data())); + return m_lctx.mk_lambda(fvars, new_body); + } + + expr optimize(expr const & e) { + switch (e.kind()) { + case expr_kind::Lambda: return optimize_lambda(e); + case expr_kind::App: return optimize_app(e); + case expr_kind::Let: return optimize_let(e); + default: return e; + } + } + public: - insert_reset_reuse_fn(environment const & /* env */) {} + insert_reset_reuse_fn(environment const & env, bool unboxed): + m_st(env), m_unboxed(unboxed) {} + expr operator()(expr const & e) { - // TODO(Leo) - return e; + return optimize(e); } }; @@ -2066,7 +2295,7 @@ pair to_llnf(environment const & env, comp_decls const for (comp_decl const & d : ds) { expr new_v = to_llnf_fn(new_env, unboxed)(d.snd()); new_v = push_proj_fn(new_env)(new_v); - new_v = insert_reset_reuse_fn(new_env)(new_v); + new_v = insert_reset_reuse_fn(new_env, unboxed)(new_v); rs.push_back(comp_decl(d.fst(), new_v)); if (unboxed) { if (optional> p = mk_boxed_version(new_env, d.fst(), get_num_nested_lambdas(d.snd()))) { diff --git a/src/library/compiler/llnf.h b/src/library/compiler/llnf.h index ea3c21f72d..c6dbf70b8a 100644 --- a/src/library/compiler/llnf.h +++ b/src/library/compiler/llnf.h @@ -17,7 +17,11 @@ optional> mk_boxed_version(environment env, name co bool is_llnf_apply(expr const & e); bool is_llnf_closure(expr const & e); -bool is_llnf_cnstr(expr const & e, unsigned & cidx, unsigned & nusize, unsigned & ssz); +bool is_llnf_cnstr(expr const & e, name & I, unsigned & cidx, unsigned & nusize, unsigned & ssz); +inline bool is_llnf_cnstr(expr const & e, unsigned & cidx, unsigned & nusize, unsigned & ssz) { + name I; + return is_llnf_cnstr(e, I, cidx, nusize, ssz); +} bool is_llnf_reuse(expr const & e, unsigned & cidx, unsigned & nusize, unsigned & ssz, bool & updt_cidx); bool is_llnf_reset(expr const & e, unsigned & n); bool is_llnf_proj(expr const & e, unsigned & idx);