From 20e7edd4ac0c520bef8cb776e6dc0edbcb55f7a5 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 2 Oct 2018 11:05:15 -0700 Subject: [PATCH] feat(library/compiler): erase trivial structures and flat cases on structures --- src/library/compiler/erase_irrelevant.cpp | 306 +++++++++++++++++----- 1 file changed, 240 insertions(+), 66 deletions(-) diff --git a/src/library/compiler/erase_irrelevant.cpp b/src/library/compiler/erase_irrelevant.cpp index f6778bfc4c..5006fc55c8 100644 --- a/src/library/compiler/erase_irrelevant.cpp +++ b/src/library/compiler/erase_irrelevant.cpp @@ -5,6 +5,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include "runtime/flet.h" +#include "kernel/kernel_exception.h" #include "kernel/instantiate.h" #include "kernel/abstract.h" #include "kernel/type_checker.h" @@ -12,44 +13,106 @@ Author: Leonardo de Moura namespace lean { class erase_irrelevant_fn { - type_checker::state m_st; - local_ctx m_lctx; + typedef std::tuple let_entry; + type_checker::state m_st; + local_ctx m_lctx; + buffer m_let_fvars; + buffer m_let_entries; + name_map> m_constructor_info; + name m_x; + unsigned m_next_idx{1}; environment & env() { return m_st.env(); } name_generator & ngen() { return m_st.ngen(); } - expr mk_runtime_type(expr e, bool atomic_only = false) { - type_checker tc(m_st, m_lctx); - e = tc.whnf(e); - if (is_constant(e)) { - name const & c = const_name(e); - if (is_runtime_scalar_type(c)) - return e; - else if (c == get_char_name()) - return mk_constant(get_uint32_name()); - else - return mk_enf_object_type(); - } else if (!atomic_only && is_app_of(e, get_array_name(), 1)) { - expr t = mk_runtime_type(app_arg(e), true); - return mk_app(app_fn(e), t); - } else if (is_sort(e)) { - return is_zero(sort_level(e)) ? mk_Prop() : mk_Type(); - } else if (tc.is_prop(e)) { - return mk_true(); + name next_name() { + name r = m_x.append_after(m_next_idx); + m_next_idx++; + return r; + } + + expr infer_type(expr const & e) { + try { + return type_checker(m_st, m_lctx).infer(e); + } catch (kernel_exception &) { + return mk_enf_object_type(); + } + } + + void get_constructor_info(name const & n, buffer & rel_fields) { + if (auto r = m_constructor_info.find(n)) { + to_buffer(*r, rel_fields); } else { + get_constructor_relevant_fields(env(), n, rel_fields); + m_constructor_info.insert(n, to_list(rel_fields)); + } + } + + /* Return (some idx) iff inductive datatype `I_name` has only one constructor, + and this constructor has only one relevant field, `idx` is the field position. */ + optional has_trivial_structure(name const & I_name) { + if (is_runtime_builtin_type(I_name)) + return optional(); + inductive_val I_val = env().get(I_name).to_inductive_val(); + if (I_val.get_ncnstrs() != 1) + return optional(); + buffer rel_fields; + get_constructor_info(head(I_val.get_cnstrs()), rel_fields); + /* The following #pragma is to disable a bogus g++ 4.9 warning at `optional r` */ + #if defined(__GNUC__) && !defined(__CLANG__) + #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" + #endif + optional result; + for (unsigned i = 0; i < rel_fields.size(); i++) { + if (rel_fields[i]) { + if (result) + return optional(); + result = i; + } + } + return result; + } + + expr mk_runtime_type(expr e, bool atomic_only = false) { + try { + type_checker tc(m_st, m_lctx); + e = tc.whnf(e); + if (is_constant(e)) { + name const & c = const_name(e); + if (is_runtime_scalar_type(c)) + return e; + else if (c == get_char_name()) + return mk_constant(get_uint32_name()); + else + return mk_enf_object_type(); + } else if (!atomic_only && is_app_of(e, get_array_name(), 1)) { + expr t = mk_runtime_type(app_arg(e), true); + return mk_app(app_fn(e), t); + } else if (is_sort(e)) { + return is_zero(sort_level(e)) ? mk_Prop() : mk_Type(); + } else if (tc.is_prop(e)) { + return mk_true(); + } else { + return mk_enf_object_type(); + } + } catch (kernel_exception &) { return mk_enf_object_type(); } } expr visit_constant(expr const & e) { lean_assert(!is_enf_neutral(e)); - type_checker tc(m_st, m_lctx); - expr e_type = tc.whnf(tc.infer(e)); - if (tc.is_prop(e_type) || is_sort(e_type)) - return mk_enf_neutral(); - else + try { + type_checker tc(m_st, m_lctx); + expr e_type = tc.whnf(tc.infer(e)); + if (tc.is_prop(e_type) || is_sort(e_type)) + return mk_enf_neutral(); + else + return mk_constant(const_name(e)); + } catch (kernel_exception &) { return mk_constant(const_name(e)); + } } bool is_atom(expr const & e) { @@ -61,15 +124,104 @@ class erase_irrelevant_fn { } } + expr visit_lambda_core(expr e, bool is_minor) { + flet save_lctx(m_lctx, m_lctx); + buffer bfvars; + buffer> entries; + while (is_lambda(e)) { + /* Types are ignored in compilation steps. So, we do not invoke visit for d. */ + expr d = instantiate_rev(binding_domain(e), bfvars.size(), bfvars.data()); + expr fvar = m_lctx.mk_local_decl(ngen(), binding_name(e), d, binding_info(e)); + bfvars.push_back(fvar); + entries.emplace_back(binding_name(e), mk_runtime_type(d)); + e = binding_body(e); + } + unsigned saved_let_fvars_size = m_let_fvars.size(); + lean_assert(m_let_entries.size() == m_let_fvars.size()); + expr r = visit(instantiate_rev(e, bfvars.size(), bfvars.data())); + r = mk_let(saved_let_fvars_size, r); + if (is_minor && is_lambda(r)) { + /* Remark: we don't want to mix the lambda for minor premise fields, with the result. */ + r = ::lean::mk_let("_x", mk_enf_object_type(), r, mk_bvar(0)); + } + r = abstract(r, bfvars.size(), bfvars.data()); + unsigned i = entries.size(); + while (i > 0) { + --i; + r = mk_lambda(entries[i].first, entries[i].second, r); + } + return r; + } + + expr visit_lambda(expr const & e) { + return visit_lambda_core(e, false); + } + + expr visit_minor(expr const & e) { + return visit_lambda_core(e, true); + } + + /* Remark: we only keep major and minor premises. */ expr visit_cases_on(expr const & c, buffer & args) { + name const & I_name = const_name(c).get_prefix(); unsigned minors_begin; unsigned minors_end; std::tie(minors_begin, minors_end) = get_cases_on_minors_range(env(), const_name(c)); - for (unsigned i = 0; i < minors_begin - 1; i++) - args[i] = mk_enf_neutral(); - for (unsigned i = minors_begin - 1; i < minors_end; i++) { - args[i] = visit(args[i]); + if (!is_runtime_builtin_type(I_name) && minors_end == minors_begin + 1) { + expr major = visit(args[minors_begin - 1]); + lean_assert(is_atom(major)); + expr minor = args[minors_begin]; + if (optional fidx = has_trivial_structure(const_name(c).get_prefix())) { + lean_assert(minors_begin + 1 == minors_end); + unsigned i = 0; + buffer fields; + while (is_lambda(minor)) { + if (i == *fidx) { + fields.push_back(major); + } else { + fields.push_back(mk_enf_neutral()); + } + i++; + minor = binding_body(minor); + } + expr r = instantiate_rev(minor, fields.size(), fields.data()); + return visit(r); + } else { + /* + ``` + prod.cases_on M (\fun a b, t) + ``` + ==> + ``` + let a := M.0 in + let b := M.1 in + t + */ + unsigned i = 0; + buffer fields; + while (is_lambda(minor)) { + expr v = mk_proj(I_name, i, major); + expr t = infer_type(v); + name n = next_name(); + expr fvar = m_lctx.mk_local_decl(ngen(), n, t, v); + fields.push_back(fvar); + expr new_t = mk_runtime_type(t); + expr new_v = visit(v); + m_let_fvars.push_back(fvar); + m_let_entries.emplace_back(n, new_t, new_v); + i++; + minor = binding_body(minor); + } + expr r = instantiate_rev(minor, fields.size(), fields.data()); + return visit(r); + } + } else { + buffer new_args; + new_args.push_back(visit(args[minors_begin - 1])); + for (unsigned i = minors_begin; i < minors_end; i++) { + new_args.push_back(visit_minor(args[i])); + } + return mk_app(c, new_args); } - return mk_app(c, args); } expr visit_app_default(expr const & fn, buffer & args) { @@ -98,6 +250,18 @@ class erase_irrelevant_fn { return visit(args[2]); } + expr visit_constructor(expr const & fn, buffer & args) { + constructor_val c_val = env().get(const_name(fn)).to_constructor_val(); + name const & I_name = c_val.get_induct(); + if (optional fidx = has_trivial_structure(I_name)) { + unsigned nparams = c_val.get_nparams(); + lean_assert(nparams + *fidx < args.size()); + return args[nparams + *fidx]; + } else { + return visit_app_default(fn, args); + } + } + expr visit_app(expr const & e) { buffer args; expr f = visit(get_app_args(e, args)); @@ -105,6 +269,8 @@ class erase_irrelevant_fn { name const & fn = const_name(f); if (fn == get_lc_proof_name()) { return mk_enf_neutral(); + } else if (is_constructor(env(), fn)) { + return visit_constructor(f, args); } else if (is_cases_on_recursor(env(), fn)) { return visit_cases_on(f, args); } else if (fn == get_quot_mk_name()) { @@ -116,52 +282,58 @@ class erase_irrelevant_fn { return visit_app_default(f, args); } - expr visit_lambda(expr e) { - flet save_lctx(m_lctx, m_lctx); - buffer fvars; - buffer> entries; - while (is_lambda(e)) { - /* Types are ignored in compilation steps. So, we do not invoke visit for d. */ - expr d = instantiate_rev(binding_domain(e), fvars.size(), fvars.data()); - expr fvar = m_lctx.mk_local_decl(ngen(), binding_name(e), d, binding_info(e)); - fvars.push_back(fvar); - entries.emplace_back(binding_name(e), mk_runtime_type(d)); - e = binding_body(e); + expr visit_proj(expr const & e) { + if (optional fidx = has_trivial_structure(proj_sname(e))) { + if (*fidx != proj_idx(e).get_small_value()) + return mk_enf_neutral(); + else + return proj_expr(e); + } else { + return e; } - expr r = visit(instantiate_rev(e, fvars.size(), fvars.data())); - r = abstract(r, fvars.size(), fvars.data()); - unsigned i = entries.size(); - while (i > 0) { + } + + expr mk_let(unsigned saved_fvars_size, expr r) { + lean_assert(saved_fvars_size <= m_let_fvars.size()); + lean_assert(m_let_fvars.size() == m_let_entries.size()); + if (saved_fvars_size == m_let_fvars.size()) + return r; + r = abstract(r, m_let_fvars.size() - saved_fvars_size, m_let_fvars.data() + saved_fvars_size); + unsigned i = m_let_fvars.size(); + while (i > saved_fvars_size) { --i; - r = mk_lambda(entries[i].first, entries[i].second, r); + expr v = abstract(std::get<2>(m_let_entries[i]), i - saved_fvars_size, m_let_fvars.data() + saved_fvars_size); + r = ::lean::mk_let(std::get<0>(m_let_entries[i]), std::get<1>(m_let_entries[i]), v, r); } + m_let_fvars.shrink(saved_fvars_size); + m_let_entries.shrink(saved_fvars_size); return r; } expr visit_let(expr e) { - flet save_lctx(m_lctx, m_lctx); - buffer fvars; - buffer> entries; + lean_assert(m_let_entries.size() == m_let_fvars.size()); + buffer curr_fvars; while (is_let(e)) { - expr t = instantiate_rev(let_type(e), fvars.size(), fvars.data()); - expr v = instantiate_rev(let_value(e), fvars.size(), fvars.data()); - expr fvar = m_lctx.mk_local_decl(ngen(), let_name(e), t, v); - fvars.push_back(fvar); - entries.emplace_back(let_name(e), mk_runtime_type(t), visit(v)); + expr t = instantiate_rev(let_type(e), curr_fvars.size(), curr_fvars.data()); + expr v = instantiate_rev(let_value(e), curr_fvars.size(), curr_fvars.data()); + name n = let_name(e); + if (is_internal_name(n) && !is_join_point_name(n)) { + n = next_name(); + } + expr fvar = m_lctx.mk_local_decl(ngen(), n, t, v); + curr_fvars.push_back(fvar); + expr new_t = mk_runtime_type(t); + expr new_v = visit(v); + m_let_fvars.push_back(fvar); + m_let_entries.emplace_back(n, new_t, new_v); e = let_body(e); } - expr r = visit(instantiate_rev(e, fvars.size(), fvars.data())); - r = abstract(r, fvars.size(), fvars.data()); - unsigned i = entries.size(); - while (i > 0) { - --i; - expr v = abstract(std::get<2>(entries[i]), i, fvars.data()); - r = mk_let(std::get<0>(entries[i]), std::get<1>(entries[i]), v, r); - } - return r; + lean_assert(m_let_entries.size() == m_let_fvars.size()); + return visit(instantiate_rev(e, curr_fvars.size(), curr_fvars.data())); } expr visit(expr const & e) { + lean_assert(m_let_entries.size() == m_let_fvars.size()); switch (e.kind()) { case expr_kind::BVar: case expr_kind::MVar: lean_unreachable(); @@ -171,7 +343,7 @@ class erase_irrelevant_fn { case expr_kind::Pi: return mk_enf_neutral(); case expr_kind::Const: return visit_constant(e); case expr_kind::App: return visit_app(e); - case expr_kind::Proj: return e; + case expr_kind::Proj: return visit_proj(e); case expr_kind::MData: return e; case expr_kind::Lambda: return visit_lambda(e); case expr_kind::Let: return visit_let(e); @@ -180,8 +352,10 @@ class erase_irrelevant_fn { } public: erase_irrelevant_fn(environment const & env, local_ctx const & lctx): - m_st(env), m_lctx(lctx) {} - expr operator()(expr const & e) { return visit(e); } + m_st(env), m_lctx(lctx), m_x("_x") {} + expr operator()(expr const & e) { + return mk_let(0, visit(e)); + } }; expr erase_irrelevant(environment const & env, local_ctx const & lctx, expr const & e) {