feat(library/compiler/csimp): eliminate join points that are used only once

This commit is contained in:
Leonardo de Moura 2019-03-14 10:43:35 -07:00
parent 44cdb1fc56
commit d3ba9ef7fa
2 changed files with 138 additions and 3 deletions

View file

@ -24,7 +24,7 @@ inline expr instantiate_rev(expr const & e, buffer<expr> const & s) {
return instantiate_rev(e, s.size(), s.data());
}
expr apply_beta(expr f, unsigned num_args, expr const * args);
expr apply_beta(expr f, unsigned num_rev_args, expr const * rev_args);
bool is_head_beta(expr const & t);
expr head_beta_reduce(expr const & t);
/* If `e` is of the form `(fun x, t) a` return `head_beta_const_fn(t)` if `t` does not depend on `x`,

View file

@ -1513,7 +1513,142 @@ public:
}
}
};
expr csimp_core(environment const & env, local_ctx const & lctx, expr const & e, bool before_erasure, csimp_cfg const & cfg) {
return csimp_fn(env, lctx, before_erasure, cfg)(e);
/* Eliminate join-points that are used only once */
class elim_jp1_fn {
environment const & m_env;
local_ctx m_lctx;
bool m_before_erasure;
name_generator m_ngen;
name_set m_to_expand;
bool m_expanded{false};
void mark_to_expand(expr const & e) {
m_to_expand.insert(fvar_name(e));
}
bool is_to_expand_jp_app(expr const & e) {
expr const & f = get_app_fn(e);
return is_fvar(f) && m_to_expand.contains(fvar_name(f));
}
expr visit_lambda(expr e) {
buffer<expr> fvars;
while (is_lambda(e)) {
expr domain = visit(instantiate_rev(binding_domain(e), fvars.size(), fvars.data()));
expr fvar = m_lctx.mk_local_decl(m_ngen, binding_name(e), domain, binding_info(e));
fvars.push_back(fvar);
e = binding_body(e);
}
e = visit(instantiate_rev(e, fvars.size(), fvars.data()));
return m_lctx.mk_lambda(fvars, e);
}
expr visit_cases(expr const & e) {
lean_assert(is_cases_on_app(m_env, e));
buffer<expr> args;
expr const & c = get_app_args(e, args);
/* simplify minor premises */
unsigned minor_idx; unsigned minors_end;
std::tie(minor_idx, minors_end) = get_cases_on_minors_range(m_env, const_name(c), m_before_erasure);
for (; minor_idx < minors_end; minor_idx++) {
args[minor_idx] = visit(args[minor_idx]);
}
return mk_app(c, args);
}
expr visit_app(expr const & e) {
lean_assert(is_app(e));
if (is_cases_on_app(m_env, e)) {
return visit_cases(e);
} else if (is_to_expand_jp_app(e)) {
buffer<expr> args;
expr const & jp = get_app_rev_args(e, args);
local_decl jp_decl = m_lctx.get_local_decl(jp);
lean_assert(is_join_point_name(jp_decl.get_user_name()));
lean_assert(jp_decl.get_value());
lean_assert(is_lambda(*jp_decl.get_value()));
return apply_beta(*jp_decl.get_value(), args.size(), args.data());
} else {
return e;
}
}
bool at_most_once(expr const & e, expr const & jp) {
lean_assert(is_fvar(jp));
bool found = false;
bool result = true;
for_each(e, [&](expr const & e, unsigned) {
if (!has_fvar(e)) return false;
if (result == false) return false; /* stop search */
if (e == jp) {
if (found) result = false;
else found = true;
return false;
}
return true;
});
return result;
}
expr visit_let(expr e) {
buffer<expr> fvars;
buffer<expr> all_fvars;
while (is_let(e)) {
expr new_type = visit(instantiate_rev(let_type(e), fvars.size(), fvars.data()));
expr new_val = visit(instantiate_rev(let_value(e), fvars.size(), fvars.data()));
expr fvar = m_lctx.mk_local_decl(m_ngen, let_name(e), new_type, new_val);
fvars.push_back(fvar);
if (is_join_point_name(let_name(e))) {
e = instantiate_rev(let_body(e), fvars.size(), fvars.data());
fvars.clear();
if (at_most_once(e, fvar)) {
m_expanded = true;
mark_to_expand(fvar);
} else {
/* Keep join point */
all_fvars.push_back(fvar);
}
} else {
all_fvars.push_back(fvar);
e = let_body(e);
}
}
e = instantiate_rev(e, fvars.size(), fvars.data());
e = visit(e);
return m_lctx.mk_lambda(all_fvars, e);
}
expr visit(expr const & e) {
switch (e.kind()) {
case expr_kind::Lambda: return visit_lambda(e);
case expr_kind::Let: return visit_let(e);
case expr_kind::App: return visit_app(e);
default: return e;
}
}
public:
elim_jp1_fn(environment const & env, local_ctx const & lctx, bool before_erasure):
m_env(env), m_lctx(lctx), m_before_erasure(before_erasure) {}
expr operator()(expr const & e) {
m_expanded = false;
return visit(e);
}
bool expanded() const { return m_expanded; }
};
expr csimp_core(environment const & env, local_ctx const & lctx, expr const & e0, bool before_erasure, csimp_cfg const & cfg) {
csimp_fn simp(env, lctx, before_erasure, cfg);
elim_jp1_fn elim_jp1(env, lctx, before_erasure);
expr e = e0;
while (true) {
e = simp(e);
expr old_e = e;
e = elim_jp1(e);
if (!elim_jp1.expanded())
return e;
}
}
}