diff --git a/src/library/compiler/cse.cpp b/src/library/compiler/cse.cpp index 5ab9183b8c..d603cd9a8e 100644 --- a/src/library/compiler/cse.cpp +++ b/src/library/compiler/cse.cpp @@ -149,6 +149,79 @@ class cse_fn : public compiler_step_visitor { } } + /* Helper functor for converting common subexpressions into fresh let-decls */ + struct cse_processor { + unsigned & m_counter; + expr_struct_set const & m_common_subexprs; + expr_struct_map m_common_subexpr_to_local; + type_context::tmp_locals m_all_locals; /* new local declarations, it also include let-decls for common-subexprs */ + local_context const & m_lctx; + + cse_processor(unsigned & counter, type_context & ctx, expr_struct_set const & s): + m_counter(counter), + m_common_subexprs(s), + m_all_locals(ctx), + m_lctx(ctx.lctx()) { + } + + virtual expr adjust_locals(expr const & v) { + return v; + } + + expr process(expr const & e, optional const & main = none_expr()) { + expr r = replace(e, [&](expr const & s, unsigned) { + if (main && s == *main) return none_expr(); + if (!is_app(s) && !is_macro(s)) return none_expr(); + if (!closed(s)) return none_expr(); + auto it1 = m_common_subexpr_to_local.find(s); + if (it1 != m_common_subexpr_to_local.end()) + return some_expr(it1->second); + if (m_common_subexprs.find(s) == m_common_subexprs.end()) + return none_expr(); + /* Eliminate common subexpressions nested in s */ + expr new_v = process(s, some_expr(s)); + name n = name("_c").append_after(m_counter); + m_counter++; + expr l = m_all_locals.push_let(n, mk_neutral_expr(), new_v); + m_common_subexpr_to_local.insert(mk_pair(s, l)); + return some_expr(l); + }); + return adjust_locals(r); + } + }; + + /* Similar to cse_processor, but has support for binding exprs (lambda and let) */ + struct cse_processor_for_binding : public cse_processor { + type_context::tmp_locals const & m_locals; + buffer m_new_locals; + + cse_processor_for_binding(unsigned & counter, type_context & ctx, type_context::tmp_locals const & locals, expr_struct_set const & s): + cse_processor(counter, ctx, s), + m_locals(locals) { + } + + virtual expr adjust_locals(expr const & v) { + return replace_locals(v, m_new_locals.size(), m_locals.data(), m_new_locals.data()); + } + + void process_locals() { + lean_assert(m_new_locals.empty()); + for (expr const & local : m_locals.as_buffer()) { + local_decl decl = m_lctx.get_local_decl(local); + if (decl.get_value()) { + /* let-entry */ + expr new_v = process(*decl.get_value()); + expr l = m_all_locals.push_let(decl.get_pp_name(), adjust_locals(decl.get_type()), new_v); + m_new_locals.push_back(l); + } else { + /* lambda-entry */ + expr l = m_all_locals.push_local(decl.get_pp_name(), adjust_locals(decl.get_type()), decl.get_info()); + m_new_locals.push_back(l); + } + } + } + }; + expr visit_lambda_let(expr const & e) { type_context::tmp_locals locals(m_ctx); expr t = e; @@ -177,54 +250,10 @@ class cse_fn : public compiler_step_visitor { if (common_subexprs.empty()) return copy_tag(e, locals.mk_lambda(t)); - expr_struct_map common_subexpr_to_local; - buffer new_locals; - type_context::tmp_locals all_locals(m_ctx); /* new local declarations + let-decls for common-subexprs */ - local_context const & lctx = m_ctx.lctx(); - - std::function const &)> - process = [&](expr const & e, optional const & main) { - return replace(e, [&](expr const & s, unsigned) { - if (main && s == *main) return none_expr(); - if (!is_app(s) && !is_macro(s)) return none_expr(); - if (!closed(s)) return none_expr(); - auto it1 = common_subexpr_to_local.find(s); - if (it1 != common_subexpr_to_local.end()) - return some_expr(it1->second); - if (common_subexprs.find(s) == common_subexprs.end()) - return none_expr(); - /* Eliminate common subexpressions nested in s */ - expr new_v = process(s, some_expr(s)); - new_v = replace_locals(new_v, new_locals.size(), locals.data(), new_locals.data()); - name n = name("_c").append_after(m_counter); - m_counter++; - expr l = all_locals.push_let(n, mk_neutral_expr(), new_v); - common_subexpr_to_local.insert(mk_pair(s, l)); - return some_expr(l); - }); - }; - - for (expr const & local : locals.as_buffer()) { - local_decl decl = lctx.get_local_decl(local); - if (decl.get_value()) { - /* let-entry */ - expr new_v = process(*decl.get_value(), none_expr()); - expr l = all_locals.push_let(decl.get_pp_name(), - replace_locals(decl.get_type(), new_locals.size(), locals.data(), new_locals.data()), - replace_locals(new_v, new_locals.size(), locals.data(), new_locals.data())); - new_locals.push_back(l); - } else { - /* lambda-entry */ - expr l = all_locals.push_local(decl.get_pp_name(), - replace_locals(decl.get_type(), new_locals.size(), locals.data(), new_locals.data()), - decl.get_info()); - new_locals.push_back(l); - } - } - - expr new_t = process(t, none_expr()); - new_t = replace_locals(new_t, new_locals.size(), locals.data(), new_locals.data()); - return copy_tag(e, all_locals.mk_lambda(new_t)); + cse_processor_for_binding proc(m_counter, m_ctx, locals, common_subexprs); + proc.process_locals(); + expr new_t = proc.process(t); + return copy_tag(e, proc.m_all_locals.mk_lambda(new_t)); } virtual expr visit_lambda(expr const & e) override {