diff --git a/src/library/compiler/erase_irrelevant.cpp b/src/library/compiler/erase_irrelevant.cpp index 48069e35e2..da5a23c9a1 100644 --- a/src/library/compiler/erase_irrelevant.cpp +++ b/src/library/compiler/erase_irrelevant.cpp @@ -21,6 +21,7 @@ class erase_irrelevant_fn { name_map> m_constructor_info; name m_x; unsigned m_next_idx{1}; + expr_map m_irrelevant_cache; environment & env() { return m_st.env(); } @@ -97,25 +98,37 @@ class erase_irrelevant_fn { } } + bool cache_is_irrelevant(expr const & e, bool r) { + if (is_constant(e) || is_fvar(e)) + m_irrelevant_cache.insert(mk_pair(e, r)); + return r; + } + bool is_irrelevant(expr const & e) { + if (is_constant(e) || is_fvar(e)) { + auto it1 = m_irrelevant_cache.find(e); + if (it1 != m_irrelevant_cache.end()) + return it1->second; + } try { type_checker tc(m_st, m_lctx); expr type = tc.whnf(tc.infer(e)); if (is_sort(type) || tc.is_prop(type)) - return true; - if (is_pi(type)) { + return cache_is_irrelevant(e, true); + expr type_it = type; + if (is_pi(type_it)) { flet save_lctx(m_lctx, m_lctx); - while (is_pi(type)) { - expr fvar = m_lctx.mk_local_decl(ngen(), binding_name(type), binding_domain(type)); - type = type_checker(m_st, m_lctx).whnf(instantiate(binding_body(type), fvar)); + while (is_pi(type_it)) { + expr fvar = m_lctx.mk_local_decl(ngen(), binding_name(type_it), binding_domain(type_it)); + type_it = type_checker(m_st, m_lctx).whnf(instantiate(binding_body(type_it), fvar)); } - if (is_sort(type)) - return true; + if (is_sort(type_it)) + return cache_is_irrelevant(e, true); } - return false; + return cache_is_irrelevant(e, false); } catch (kernel_exception &) { /* failed to infer type or normalize, assume it is relevant */ - return false; + return cache_is_irrelevant(e, false); } } @@ -126,14 +139,21 @@ class erase_irrelevant_fn { return mk_enf_unreachable(); } else if (c == get_lc_proof_name()) { return mk_enf_neutral(); - } - if (is_irrelevant(e)) { + } else if (is_irrelevant(e)) { return mk_enf_neutral(); } else { return mk_constant(const_name(e)); } } + expr visit_fvar(expr const & e) { + if (is_irrelevant(e)) { + return mk_enf_neutral(); + } else { + return e; + } + } + bool is_atom(expr const & e) { switch (e.kind()) { case expr_kind::FVar: return true; @@ -189,7 +209,7 @@ class erase_irrelevant_fn { unsigned minors_begin; unsigned minors_end; std::tie(minors_begin, minors_end) = get_cases_on_minors_range(env(), const_name(c)); if (!is_runtime_builtin_type(I_name) && minors_end == minors_begin + 1) { - expr major = visit(args[minors_begin - 1]); + expr major = args[minors_begin - 1]; lean_assert(is_atom(major)); expr minor = args[minors_begin]; optional fidx = has_trivial_structure(const_name(c).get_prefix()); @@ -245,13 +265,14 @@ class erase_irrelevant_fn { } } - expr visit_app_default(expr const & fn, buffer & args) { + expr visit_app_default(expr fn, buffer & args) { + fn = visit(fn); for (expr & arg : args) { if (!is_atom(arg)) { // In LCNF, relevant arguments are atomic arg = mk_enf_neutral(); - } else if (is_constant(arg)) { - arg = visit_constant(arg); + } else { + arg = visit(arg); } } return mk_app(fn, args); @@ -259,7 +280,7 @@ class erase_irrelevant_fn { expr visit_quot_lift(buffer & args) { lean_assert(args.size() >= 6); - expr f = visit(args[3]); + expr f = args[3]; buffer new_args; for (unsigned i = 5; i < args.size(); i++) new_args.push_back(args[i]); @@ -302,7 +323,6 @@ class erase_irrelevant_fn { return visit_quot_lift(args); } } - f = visit(f); return visit_app_default(f, args); } @@ -313,7 +333,7 @@ class erase_irrelevant_fn { else return visit(proj_expr(e)); } else { - return e; + return update_proj(e, visit(proj_expr(e))); } } @@ -356,19 +376,23 @@ class erase_irrelevant_fn { return visit(instantiate_rev(e, curr_fvars.size(), curr_fvars.data())); } + expr visit_mdata(expr const & e) { + return update_mdata(e, visit(mdata_expr(e))); + } + expr visit(expr const & e) { lean_assert(m_let_entries.size() == m_let_fvars.size()); switch (e.kind()) { case expr_kind::BVar: case expr_kind::MVar: lean_unreachable(); - case expr_kind::FVar: return e; + case expr_kind::FVar: return visit_fvar(e); case expr_kind::Sort: return mk_enf_neutral(); case expr_kind::Lit: return e; case expr_kind::Pi: return mk_enf_neutral(); case expr_kind::Const: return visit_constant(e); case expr_kind::App: return visit_app(e); case expr_kind::Proj: return visit_proj(e); - case expr_kind::MData: return e; + case expr_kind::MData: return visit_mdata(e); case expr_kind::Lambda: return visit_lambda(e); case expr_kind::Let: return visit_let(e); }