diff --git a/src/kernel/instantiate.h b/src/kernel/instantiate.h index dbd050caa2..d00d8347e4 100644 --- a/src/kernel/instantiate.h +++ b/src/kernel/instantiate.h @@ -24,7 +24,7 @@ inline expr instantiate_rev(expr const & e, buffer const & s) { return instantiate_rev(e, s.size(), s.data()); } -expr apply_beta(expr f, unsigned num_args, expr const * args); +expr apply_beta(expr f, unsigned num_rev_args, expr const * rev_args); bool is_head_beta(expr const & t); expr head_beta_reduce(expr const & t); /* If `e` is of the form `(fun x, t) a` return `head_beta_const_fn(t)` if `t` does not depend on `x`, diff --git a/src/library/compiler/csimp.cpp b/src/library/compiler/csimp.cpp index 6c792140e7..cc18f48af7 100644 --- a/src/library/compiler/csimp.cpp +++ b/src/library/compiler/csimp.cpp @@ -1513,7 +1513,142 @@ public: } } }; -expr csimp_core(environment const & env, local_ctx const & lctx, expr const & e, bool before_erasure, csimp_cfg const & cfg) { - return csimp_fn(env, lctx, before_erasure, cfg)(e); + +/* Eliminate join-points that are used only once */ +class elim_jp1_fn { + environment const & m_env; + local_ctx m_lctx; + bool m_before_erasure; + name_generator m_ngen; + name_set m_to_expand; + bool m_expanded{false}; + + void mark_to_expand(expr const & e) { + m_to_expand.insert(fvar_name(e)); + } + + bool is_to_expand_jp_app(expr const & e) { + expr const & f = get_app_fn(e); + return is_fvar(f) && m_to_expand.contains(fvar_name(f)); + } + + expr visit_lambda(expr e) { + buffer fvars; + while (is_lambda(e)) { + expr domain = visit(instantiate_rev(binding_domain(e), fvars.size(), fvars.data())); + expr fvar = m_lctx.mk_local_decl(m_ngen, binding_name(e), domain, binding_info(e)); + fvars.push_back(fvar); + e = binding_body(e); + } + e = visit(instantiate_rev(e, fvars.size(), fvars.data())); + return m_lctx.mk_lambda(fvars, e); + } + + expr visit_cases(expr const & e) { + lean_assert(is_cases_on_app(m_env, e)); + buffer args; + expr const & c = get_app_args(e, args); + /* simplify minor premises */ + unsigned minor_idx; unsigned minors_end; + std::tie(minor_idx, minors_end) = get_cases_on_minors_range(m_env, const_name(c), m_before_erasure); + for (; minor_idx < minors_end; minor_idx++) { + args[minor_idx] = visit(args[minor_idx]); + } + return mk_app(c, args); + } + + expr visit_app(expr const & e) { + lean_assert(is_app(e)); + if (is_cases_on_app(m_env, e)) { + return visit_cases(e); + } else if (is_to_expand_jp_app(e)) { + buffer args; + expr const & jp = get_app_rev_args(e, args); + local_decl jp_decl = m_lctx.get_local_decl(jp); + lean_assert(is_join_point_name(jp_decl.get_user_name())); + lean_assert(jp_decl.get_value()); + lean_assert(is_lambda(*jp_decl.get_value())); + return apply_beta(*jp_decl.get_value(), args.size(), args.data()); + } else { + return e; + } + } + + bool at_most_once(expr const & e, expr const & jp) { + lean_assert(is_fvar(jp)); + bool found = false; + bool result = true; + for_each(e, [&](expr const & e, unsigned) { + if (!has_fvar(e)) return false; + if (result == false) return false; /* stop search */ + if (e == jp) { + if (found) result = false; + else found = true; + return false; + } + return true; + }); + return result; + } + + expr visit_let(expr e) { + buffer fvars; + buffer all_fvars; + while (is_let(e)) { + expr new_type = visit(instantiate_rev(let_type(e), fvars.size(), fvars.data())); + expr new_val = visit(instantiate_rev(let_value(e), fvars.size(), fvars.data())); + expr fvar = m_lctx.mk_local_decl(m_ngen, let_name(e), new_type, new_val); + fvars.push_back(fvar); + if (is_join_point_name(let_name(e))) { + e = instantiate_rev(let_body(e), fvars.size(), fvars.data()); + fvars.clear(); + if (at_most_once(e, fvar)) { + m_expanded = true; + mark_to_expand(fvar); + } else { + /* Keep join point */ + all_fvars.push_back(fvar); + } + } else { + all_fvars.push_back(fvar); + e = let_body(e); + } + } + e = instantiate_rev(e, fvars.size(), fvars.data()); + e = visit(e); + return m_lctx.mk_lambda(all_fvars, e); + } + + expr visit(expr const & e) { + switch (e.kind()) { + case expr_kind::Lambda: return visit_lambda(e); + case expr_kind::Let: return visit_let(e); + case expr_kind::App: return visit_app(e); + default: return e; + } + } + +public: + elim_jp1_fn(environment const & env, local_ctx const & lctx, bool before_erasure): + m_env(env), m_lctx(lctx), m_before_erasure(before_erasure) {} + expr operator()(expr const & e) { + m_expanded = false; + return visit(e); + } + + bool expanded() const { return m_expanded; } +}; + +expr csimp_core(environment const & env, local_ctx const & lctx, expr const & e0, bool before_erasure, csimp_cfg const & cfg) { + csimp_fn simp(env, lctx, before_erasure, cfg); + elim_jp1_fn elim_jp1(env, lctx, before_erasure); + expr e = e0; + while (true) { + e = simp(e); + expr old_e = e; + e = elim_jp1(e); + if (!elim_jp1.expanded()) + return e; + } } }