refactor(library/compiler/cse): isolate cse processor procedure

We will use it to fix a performance problem in case expressions for
inductive datatypes that have constructors without data.
This commit is contained in:
Leonardo de Moura 2017-11-09 12:41:36 -08:00
parent fabf7f6380
commit 76ddda493d

View file

@ -149,6 +149,79 @@ class cse_fn : public compiler_step_visitor {
}
}
/* Helper functor for converting common subexpressions into fresh let-decls */
struct cse_processor {
unsigned & m_counter;
expr_struct_set const & m_common_subexprs;
expr_struct_map<expr> m_common_subexpr_to_local;
type_context::tmp_locals m_all_locals; /* new local declarations, it also include let-decls for common-subexprs */
local_context const & m_lctx;
cse_processor(unsigned & counter, type_context & ctx, expr_struct_set const & s):
m_counter(counter),
m_common_subexprs(s),
m_all_locals(ctx),
m_lctx(ctx.lctx()) {
}
virtual expr adjust_locals(expr const & v) {
return v;
}
expr process(expr const & e, optional<expr> const & main = none_expr()) {
expr r = replace(e, [&](expr const & s, unsigned) {
if (main && s == *main) return none_expr();
if (!is_app(s) && !is_macro(s)) return none_expr();
if (!closed(s)) return none_expr();
auto it1 = m_common_subexpr_to_local.find(s);
if (it1 != m_common_subexpr_to_local.end())
return some_expr(it1->second);
if (m_common_subexprs.find(s) == m_common_subexprs.end())
return none_expr();
/* Eliminate common subexpressions nested in s */
expr new_v = process(s, some_expr(s));
name n = name("_c").append_after(m_counter);
m_counter++;
expr l = m_all_locals.push_let(n, mk_neutral_expr(), new_v);
m_common_subexpr_to_local.insert(mk_pair(s, l));
return some_expr(l);
});
return adjust_locals(r);
}
};
/* Similar to cse_processor, but has support for binding exprs (lambda and let) */
struct cse_processor_for_binding : public cse_processor {
type_context::tmp_locals const & m_locals;
buffer<expr> m_new_locals;
cse_processor_for_binding(unsigned & counter, type_context & ctx, type_context::tmp_locals const & locals, expr_struct_set const & s):
cse_processor(counter, ctx, s),
m_locals(locals) {
}
virtual expr adjust_locals(expr const & v) {
return replace_locals(v, m_new_locals.size(), m_locals.data(), m_new_locals.data());
}
void process_locals() {
lean_assert(m_new_locals.empty());
for (expr const & local : m_locals.as_buffer()) {
local_decl decl = m_lctx.get_local_decl(local);
if (decl.get_value()) {
/* let-entry */
expr new_v = process(*decl.get_value());
expr l = m_all_locals.push_let(decl.get_pp_name(), adjust_locals(decl.get_type()), new_v);
m_new_locals.push_back(l);
} else {
/* lambda-entry */
expr l = m_all_locals.push_local(decl.get_pp_name(), adjust_locals(decl.get_type()), decl.get_info());
m_new_locals.push_back(l);
}
}
}
};
expr visit_lambda_let(expr const & e) {
type_context::tmp_locals locals(m_ctx);
expr t = e;
@ -177,54 +250,10 @@ class cse_fn : public compiler_step_visitor {
if (common_subexprs.empty())
return copy_tag(e, locals.mk_lambda(t));
expr_struct_map<expr> common_subexpr_to_local;
buffer<expr> new_locals;
type_context::tmp_locals all_locals(m_ctx); /* new local declarations + let-decls for common-subexprs */
local_context const & lctx = m_ctx.lctx();
std::function<expr(expr const &, optional<expr> const &)>
process = [&](expr const & e, optional<expr> const & main) {
return replace(e, [&](expr const & s, unsigned) {
if (main && s == *main) return none_expr();
if (!is_app(s) && !is_macro(s)) return none_expr();
if (!closed(s)) return none_expr();
auto it1 = common_subexpr_to_local.find(s);
if (it1 != common_subexpr_to_local.end())
return some_expr(it1->second);
if (common_subexprs.find(s) == common_subexprs.end())
return none_expr();
/* Eliminate common subexpressions nested in s */
expr new_v = process(s, some_expr(s));
new_v = replace_locals(new_v, new_locals.size(), locals.data(), new_locals.data());
name n = name("_c").append_after(m_counter);
m_counter++;
expr l = all_locals.push_let(n, mk_neutral_expr(), new_v);
common_subexpr_to_local.insert(mk_pair(s, l));
return some_expr(l);
});
};
for (expr const & local : locals.as_buffer()) {
local_decl decl = lctx.get_local_decl(local);
if (decl.get_value()) {
/* let-entry */
expr new_v = process(*decl.get_value(), none_expr());
expr l = all_locals.push_let(decl.get_pp_name(),
replace_locals(decl.get_type(), new_locals.size(), locals.data(), new_locals.data()),
replace_locals(new_v, new_locals.size(), locals.data(), new_locals.data()));
new_locals.push_back(l);
} else {
/* lambda-entry */
expr l = all_locals.push_local(decl.get_pp_name(),
replace_locals(decl.get_type(), new_locals.size(), locals.data(), new_locals.data()),
decl.get_info());
new_locals.push_back(l);
}
}
expr new_t = process(t, none_expr());
new_t = replace_locals(new_t, new_locals.size(), locals.data(), new_locals.data());
return copy_tag(e, all_locals.mk_lambda(new_t));
cse_processor_for_binding proc(m_counter, m_ctx, locals, common_subexprs);
proc.process_locals();
expr new_t = proc.process(t);
return copy_tag(e, proc.m_all_locals.mk_lambda(new_t));
}
virtual expr visit_lambda(expr const & e) override {