feat(library/compiler/csimp): eliminate join points that are used only once
This commit is contained in:
parent
44cdb1fc56
commit
d3ba9ef7fa
2 changed files with 138 additions and 3 deletions
|
|
@ -24,7 +24,7 @@ inline expr instantiate_rev(expr const & e, buffer<expr> const & s) {
|
|||
return instantiate_rev(e, s.size(), s.data());
|
||||
}
|
||||
|
||||
expr apply_beta(expr f, unsigned num_args, expr const * args);
|
||||
expr apply_beta(expr f, unsigned num_rev_args, expr const * rev_args);
|
||||
bool is_head_beta(expr const & t);
|
||||
expr head_beta_reduce(expr const & t);
|
||||
/* If `e` is of the form `(fun x, t) a` return `head_beta_const_fn(t)` if `t` does not depend on `x`,
|
||||
|
|
|
|||
|
|
@ -1513,7 +1513,142 @@ public:
|
|||
}
|
||||
}
|
||||
};
|
||||
expr csimp_core(environment const & env, local_ctx const & lctx, expr const & e, bool before_erasure, csimp_cfg const & cfg) {
|
||||
return csimp_fn(env, lctx, before_erasure, cfg)(e);
|
||||
|
||||
/* Eliminate join-points that are used only once */
|
||||
class elim_jp1_fn {
|
||||
environment const & m_env;
|
||||
local_ctx m_lctx;
|
||||
bool m_before_erasure;
|
||||
name_generator m_ngen;
|
||||
name_set m_to_expand;
|
||||
bool m_expanded{false};
|
||||
|
||||
void mark_to_expand(expr const & e) {
|
||||
m_to_expand.insert(fvar_name(e));
|
||||
}
|
||||
|
||||
bool is_to_expand_jp_app(expr const & e) {
|
||||
expr const & f = get_app_fn(e);
|
||||
return is_fvar(f) && m_to_expand.contains(fvar_name(f));
|
||||
}
|
||||
|
||||
expr visit_lambda(expr e) {
|
||||
buffer<expr> fvars;
|
||||
while (is_lambda(e)) {
|
||||
expr domain = visit(instantiate_rev(binding_domain(e), fvars.size(), fvars.data()));
|
||||
expr fvar = m_lctx.mk_local_decl(m_ngen, binding_name(e), domain, binding_info(e));
|
||||
fvars.push_back(fvar);
|
||||
e = binding_body(e);
|
||||
}
|
||||
e = visit(instantiate_rev(e, fvars.size(), fvars.data()));
|
||||
return m_lctx.mk_lambda(fvars, e);
|
||||
}
|
||||
|
||||
expr visit_cases(expr const & e) {
|
||||
lean_assert(is_cases_on_app(m_env, e));
|
||||
buffer<expr> args;
|
||||
expr const & c = get_app_args(e, args);
|
||||
/* simplify minor premises */
|
||||
unsigned minor_idx; unsigned minors_end;
|
||||
std::tie(minor_idx, minors_end) = get_cases_on_minors_range(m_env, const_name(c), m_before_erasure);
|
||||
for (; minor_idx < minors_end; minor_idx++) {
|
||||
args[minor_idx] = visit(args[minor_idx]);
|
||||
}
|
||||
return mk_app(c, args);
|
||||
}
|
||||
|
||||
expr visit_app(expr const & e) {
|
||||
lean_assert(is_app(e));
|
||||
if (is_cases_on_app(m_env, e)) {
|
||||
return visit_cases(e);
|
||||
} else if (is_to_expand_jp_app(e)) {
|
||||
buffer<expr> args;
|
||||
expr const & jp = get_app_rev_args(e, args);
|
||||
local_decl jp_decl = m_lctx.get_local_decl(jp);
|
||||
lean_assert(is_join_point_name(jp_decl.get_user_name()));
|
||||
lean_assert(jp_decl.get_value());
|
||||
lean_assert(is_lambda(*jp_decl.get_value()));
|
||||
return apply_beta(*jp_decl.get_value(), args.size(), args.data());
|
||||
} else {
|
||||
return e;
|
||||
}
|
||||
}
|
||||
|
||||
bool at_most_once(expr const & e, expr const & jp) {
|
||||
lean_assert(is_fvar(jp));
|
||||
bool found = false;
|
||||
bool result = true;
|
||||
for_each(e, [&](expr const & e, unsigned) {
|
||||
if (!has_fvar(e)) return false;
|
||||
if (result == false) return false; /* stop search */
|
||||
if (e == jp) {
|
||||
if (found) result = false;
|
||||
else found = true;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
expr visit_let(expr e) {
|
||||
buffer<expr> fvars;
|
||||
buffer<expr> all_fvars;
|
||||
while (is_let(e)) {
|
||||
expr new_type = visit(instantiate_rev(let_type(e), fvars.size(), fvars.data()));
|
||||
expr new_val = visit(instantiate_rev(let_value(e), fvars.size(), fvars.data()));
|
||||
expr fvar = m_lctx.mk_local_decl(m_ngen, let_name(e), new_type, new_val);
|
||||
fvars.push_back(fvar);
|
||||
if (is_join_point_name(let_name(e))) {
|
||||
e = instantiate_rev(let_body(e), fvars.size(), fvars.data());
|
||||
fvars.clear();
|
||||
if (at_most_once(e, fvar)) {
|
||||
m_expanded = true;
|
||||
mark_to_expand(fvar);
|
||||
} else {
|
||||
/* Keep join point */
|
||||
all_fvars.push_back(fvar);
|
||||
}
|
||||
} else {
|
||||
all_fvars.push_back(fvar);
|
||||
e = let_body(e);
|
||||
}
|
||||
}
|
||||
e = instantiate_rev(e, fvars.size(), fvars.data());
|
||||
e = visit(e);
|
||||
return m_lctx.mk_lambda(all_fvars, e);
|
||||
}
|
||||
|
||||
expr visit(expr const & e) {
|
||||
switch (e.kind()) {
|
||||
case expr_kind::Lambda: return visit_lambda(e);
|
||||
case expr_kind::Let: return visit_let(e);
|
||||
case expr_kind::App: return visit_app(e);
|
||||
default: return e;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
elim_jp1_fn(environment const & env, local_ctx const & lctx, bool before_erasure):
|
||||
m_env(env), m_lctx(lctx), m_before_erasure(before_erasure) {}
|
||||
expr operator()(expr const & e) {
|
||||
m_expanded = false;
|
||||
return visit(e);
|
||||
}
|
||||
|
||||
bool expanded() const { return m_expanded; }
|
||||
};
|
||||
|
||||
expr csimp_core(environment const & env, local_ctx const & lctx, expr const & e0, bool before_erasure, csimp_cfg const & cfg) {
|
||||
csimp_fn simp(env, lctx, before_erasure, cfg);
|
||||
elim_jp1_fn elim_jp1(env, lctx, before_erasure);
|
||||
expr e = e0;
|
||||
while (true) {
|
||||
e = simp(e);
|
||||
expr old_e = e;
|
||||
e = elim_jp1(e);
|
||||
if (!elim_jp1.expanded())
|
||||
return e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue