feat(library/compiler): add cce: common case elimination
This commit is contained in:
parent
1534f17a89
commit
ff2e28e557
6 changed files with 282 additions and 17 deletions
|
|
@ -6,11 +6,15 @@ Author: Leonardo de Moura
|
|||
*/
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include "runtime/flet.h"
|
||||
#include "util/name_generator.h"
|
||||
#include "kernel/environment.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "kernel/abstract.h"
|
||||
#include "kernel/for_each_fn.h"
|
||||
#include "kernel/replace_fn.h"
|
||||
#include "kernel/expr_maps.h"
|
||||
#include "kernel/expr_sets.h"
|
||||
#include "library/compiler/util.h"
|
||||
|
||||
namespace lean {
|
||||
|
|
@ -131,6 +135,249 @@ expr cse(environment const & env, expr const & e) {
|
|||
return cse_fn(env)(e);
|
||||
}
|
||||
|
||||
/* Common case elimination.
|
||||
|
||||
This transformation creates join-points for identical minor premises.
|
||||
This is important in code such as
|
||||
```
|
||||
def get_fn : expr -> tactic expr
|
||||
| (expr.app f _) := pure f
|
||||
| _ := throw "expr is not an application"
|
||||
```
|
||||
The "else"-branch is duplicated by the equation compiler for each constructor different from `expr.app`. */
|
||||
class cce_fn {
|
||||
type_checker::state m_st;
|
||||
local_ctx m_lctx;
|
||||
buffer<expr> m_fvars;
|
||||
expr_map<bool> m_cce_candidates;
|
||||
buffer<expr> m_cce_targets;
|
||||
name m_j;
|
||||
unsigned m_next_idx{1};
|
||||
public:
|
||||
environment & env() { return m_st.env(); }
|
||||
|
||||
name_generator & ngen() { return m_st.ngen(); }
|
||||
|
||||
unsigned get_fvar_idx(expr const & x) {
|
||||
return m_lctx.get_local_decl(x).get_idx();
|
||||
}
|
||||
|
||||
unsigned get_max_fvar_idx(expr const & e) {
|
||||
if (!has_fvar(e))
|
||||
return 0;
|
||||
unsigned r = 0;
|
||||
for_each(e, [&](expr const & x, unsigned) {
|
||||
if (!has_fvar(x)) return false;
|
||||
if (is_fvar(x)) {
|
||||
unsigned x_idx = get_fvar_idx(x);
|
||||
if (x_idx > r)
|
||||
r = x_idx;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return r;
|
||||
}
|
||||
|
||||
expr replace_target(expr const & e, expr const & target, expr const & jmp) {
|
||||
return replace(e, [&](expr const & t, unsigned) {
|
||||
if (target == t) {
|
||||
return some_expr(jmp);
|
||||
}
|
||||
return none_expr();
|
||||
});
|
||||
}
|
||||
|
||||
expr mk_let_lambda(unsigned old_fvars_size, expr body, bool is_let) {
|
||||
lean_assert(m_fvars.size() >= old_fvars_size);
|
||||
if (m_fvars.size() == old_fvars_size)
|
||||
return body;
|
||||
unsigned first_var_idx;
|
||||
if (old_fvars_size == 0)
|
||||
first_var_idx = 0;
|
||||
else
|
||||
first_var_idx = get_fvar_idx(m_fvars[old_fvars_size]);
|
||||
unsigned j = 0;
|
||||
buffer<pair<expr, expr>> target_jmp_pairs;
|
||||
name_set new_fvar_names;
|
||||
for (unsigned i = 0; i < m_cce_targets.size(); i++) {
|
||||
expr const & target = m_cce_targets[i];
|
||||
unsigned max_idx = get_max_fvar_idx(target);
|
||||
if (max_idx >= first_var_idx) {
|
||||
expr target_type = cheap_beta_reduce(type_checker(m_st, m_lctx).infer(target));
|
||||
expr unit = mk_unit(mk_level_one());
|
||||
expr unit_mk = mk_unit_mk(mk_level_one());
|
||||
expr new_val = ::lean::mk_lambda("u", unit, target);
|
||||
expr new_type = ::lean::mk_arrow(unit, target_type);
|
||||
expr new_fvar = m_lctx.mk_local_decl(ngen(), mk_join_point_name(m_j.append_after(m_next_idx)), new_type, new_val);
|
||||
new_fvar_names.insert(fvar_name(new_fvar));
|
||||
expr jmp = ::lean::mk_let("_j", target_type, mk_app(new_fvar, unit_mk), mk_bvar(0));
|
||||
if (is_let) {
|
||||
/* We must insert new_fvar after fvar with idx == max_idx */
|
||||
m_next_idx++;
|
||||
unsigned k = old_fvars_size;
|
||||
for (; k < m_fvars.size(); k++) {
|
||||
expr const & fvar = m_fvars[k];
|
||||
if (get_fvar_idx(fvar) > max_idx) {
|
||||
m_fvars.insert(k, new_fvar);
|
||||
/* We need to save the pairs to replace the `target` on let-declarations that occurr after k */
|
||||
target_jmp_pairs.emplace_back(target, jmp);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (k == m_fvars.size()) {
|
||||
m_fvars.push_back(new_fvar);
|
||||
}
|
||||
} else {
|
||||
lean_assert(!is_let);
|
||||
/* For lambda we add new free variable after lambda vars */
|
||||
m_fvars.push_back(new_fvar);
|
||||
}
|
||||
body = replace_target(body, target, jmp);
|
||||
} else {
|
||||
m_cce_targets[j] = target;
|
||||
j++;
|
||||
}
|
||||
}
|
||||
m_cce_targets.shrink(j);
|
||||
if (is_let && !target_jmp_pairs.empty()) {
|
||||
expr r = abstract(body, m_fvars.size() - old_fvars_size, m_fvars.data() + old_fvars_size);
|
||||
unsigned i = m_fvars.size();
|
||||
while (i > old_fvars_size) {
|
||||
--i;
|
||||
expr fvar = m_fvars[i];
|
||||
local_decl decl = m_lctx.get_local_decl(fvar);
|
||||
expr type = abstract(decl.get_type(), i - old_fvars_size, m_fvars.data() + old_fvars_size);
|
||||
lean_assert(decl.get_value());
|
||||
expr val = *decl.get_value();
|
||||
if ((!new_fvar_names.contains(fvar_name(fvar))) &&
|
||||
(is_lambda(val) || is_cases_on_app(env(), val))) {
|
||||
for (pair<expr, expr> const & p : target_jmp_pairs) {
|
||||
val = replace_target(val, p.first, p.second);
|
||||
}
|
||||
}
|
||||
val = abstract(val, i - old_fvars_size, m_fvars.data() + old_fvars_size);
|
||||
r = ::lean::mk_let(decl.get_user_name(), type, val, r);
|
||||
}
|
||||
m_fvars.shrink(old_fvars_size);
|
||||
return r;
|
||||
} else {
|
||||
expr r = m_lctx.mk_lambda(m_fvars.size() - old_fvars_size, m_fvars.data() + old_fvars_size, body);
|
||||
m_fvars.shrink(old_fvars_size);
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
expr mk_let(unsigned old_fvars_size, expr const & body) { return mk_let_lambda(old_fvars_size, body, true); }
|
||||
|
||||
expr mk_lambda(unsigned old_fvars_size, expr const & body) { return mk_let_lambda(old_fvars_size, body, false); }
|
||||
|
||||
expr visit_let(expr e) {
|
||||
buffer<expr> let_fvars;
|
||||
while (is_let(e)) {
|
||||
expr new_type = instantiate_rev(let_type(e), let_fvars.size(), let_fvars.data());
|
||||
expr new_val = visit_let_value(instantiate_rev(let_value(e), let_fvars.size(), let_fvars.data()));
|
||||
expr new_fvar = m_lctx.mk_local_decl(ngen(), let_name(e), new_type, new_val);
|
||||
let_fvars.push_back(new_fvar);
|
||||
m_fvars.push_back(new_fvar);
|
||||
e = let_body(e);
|
||||
}
|
||||
return instantiate_rev(e, let_fvars.size(), let_fvars.data());
|
||||
}
|
||||
|
||||
expr visit_lambda(expr e) {
|
||||
lean_assert(is_lambda(e));
|
||||
flet<local_ctx> save_lctx(m_lctx, m_lctx);
|
||||
unsigned fvars_sz1 = m_fvars.size();
|
||||
while (is_lambda(e)) {
|
||||
/* Types are ignored in compilation steps. So, we do not invoke visit for d. */
|
||||
expr new_d = instantiate_rev(binding_domain(e), m_fvars.size() - fvars_sz1, m_fvars.data() + fvars_sz1);
|
||||
expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(e), new_d, binding_info(e));
|
||||
m_fvars.push_back(new_fvar);
|
||||
e = binding_body(e);
|
||||
}
|
||||
unsigned fvars_sz2 = m_fvars.size();
|
||||
expr new_body = visit(instantiate_rev(e, m_fvars.size() - fvars_sz1, m_fvars.data() + fvars_sz1));
|
||||
new_body = mk_let(fvars_sz2, new_body);
|
||||
return mk_lambda(fvars_sz1, new_body);
|
||||
}
|
||||
|
||||
void add_candidate(expr const & e) {
|
||||
auto it = m_cce_candidates.find(e);
|
||||
if (it == m_cce_candidates.end()) {
|
||||
m_cce_candidates.insert(mk_pair(e, true));
|
||||
} else if (it->second) {
|
||||
m_cce_targets.push_back(e);
|
||||
it->second = false;
|
||||
}
|
||||
}
|
||||
|
||||
expr visit_app(expr const & e) {
|
||||
if (!is_cases_on_app(env(), e)) return e;
|
||||
buffer<expr> args;
|
||||
expr const & c = get_app_args(e, args);
|
||||
lean_assert(is_constant(c));
|
||||
inductive_val I_val = env().get(const_name(c).get_prefix()).to_inductive_val();
|
||||
unsigned motive_idx = I_val.get_nparams();
|
||||
unsigned first_index = motive_idx + 1;
|
||||
unsigned nindices = I_val.get_nindices();
|
||||
unsigned major_idx = first_index + nindices;
|
||||
unsigned first_minor_idx = major_idx + 1;
|
||||
unsigned nminors = length(I_val.get_cnstrs());
|
||||
/* visit minor premises */
|
||||
for (unsigned i = 0; i < nminors; i++) {
|
||||
unsigned minor_idx = first_minor_idx + i;
|
||||
expr minor = args[minor_idx];
|
||||
flet<local_ctx> save_lctx(m_lctx, m_lctx);
|
||||
unsigned fvars_sz1 = m_fvars.size();
|
||||
while (is_lambda(minor)) {
|
||||
expr new_d = instantiate_rev(binding_domain(minor), m_fvars.size() - fvars_sz1, m_fvars.data() + fvars_sz1);
|
||||
expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(minor), new_d, binding_info(minor));
|
||||
m_fvars.push_back(new_fvar);
|
||||
minor = binding_body(minor);
|
||||
}
|
||||
bool is_cce_target = !has_loose_bvars(minor);
|
||||
unsigned fvars_sz2 = m_fvars.size();
|
||||
expr new_minor = visit(instantiate_rev(minor, m_fvars.size() - fvars_sz1, m_fvars.data() + fvars_sz1));
|
||||
new_minor = mk_let(fvars_sz2, new_minor);
|
||||
if (is_cce_target && !is_lcnf_atom(new_minor))
|
||||
add_candidate(new_minor);
|
||||
new_minor = mk_lambda(fvars_sz1, new_minor);
|
||||
args[minor_idx] = new_minor;
|
||||
}
|
||||
return mk_app(c, args);
|
||||
}
|
||||
|
||||
expr visit_let_value(expr const & e) {
|
||||
switch (e.kind()) {
|
||||
case expr_kind::Lambda: return visit_lambda(e);
|
||||
case expr_kind::App: return visit_app(e);
|
||||
default: return e;
|
||||
}
|
||||
}
|
||||
|
||||
expr visit(expr const & e) {
|
||||
switch (e.kind()) {
|
||||
case expr_kind::Lambda: return visit_lambda(e);
|
||||
case expr_kind::Let: return visit_let(e);
|
||||
default: return e;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
cce_fn(environment const & env, local_ctx const & lctx):
|
||||
m_st(env), m_lctx(lctx), m_j("_j") {
|
||||
}
|
||||
|
||||
expr operator()(expr const & e) {
|
||||
expr r = visit(e);
|
||||
return mk_let(0, r);
|
||||
}
|
||||
};
|
||||
|
||||
expr cce(environment const & env, local_ctx const & lctx, expr const & e) {
|
||||
return cce_fn(env, lctx)(e);
|
||||
}
|
||||
|
||||
void initialize_cse() {
|
||||
g_cse_fresh = new name("_cse_fresh");
|
||||
register_name_generator_prefix(*g_cse_fresh);
|
||||
|
|
|
|||
|
|
@ -7,7 +7,10 @@ Author: Leonardo de Moura
|
|||
#pragma once
|
||||
#include "kernel/environment.h"
|
||||
namespace lean {
|
||||
/* Common subexpression elimination */
|
||||
expr cse(environment const & env, expr const & e);
|
||||
/* Common case elimination */
|
||||
expr cce(environment const & env, local_ctx const & lctx, expr const & e);
|
||||
void initialize_cse();
|
||||
void finalize_cse();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,7 +32,8 @@ class csimp_fn {
|
|||
if (is_fvar(e)) {
|
||||
if (optional<local_decl> decl = m_lctx.find_local_decl(e)) {
|
||||
if (optional<expr> v = decl->get_value())
|
||||
return find(*v, skip_mdata);
|
||||
if (!is_join_point_name(decl->get_user_name()))
|
||||
return find(*v, skip_mdata);
|
||||
}
|
||||
} else if (is_mdata(e) && skip_mdata) {
|
||||
return find(mdata_expr(e), true);
|
||||
|
|
@ -131,7 +132,9 @@ class csimp_fn {
|
|||
if (is_lcnf_atom(new_val)) {
|
||||
let_fvars.push_back(new_val);
|
||||
} else {
|
||||
name n = is_internal_name(let_name(e)) ? next_name() : let_name(e);
|
||||
name n = let_name(e);
|
||||
if (is_internal_name(n) && !is_join_point_name(n))
|
||||
n = next_name();
|
||||
expr new_fvar = m_lctx.mk_local_decl(ngen(), n, new_type, new_val);
|
||||
let_fvars.push_back(new_fvar);
|
||||
m_fvars.push_back(new_fvar);
|
||||
|
|
|
|||
|
|
@ -238,23 +238,26 @@ class preprocess_fn {
|
|||
name n = get_real_name(d.get_name());
|
||||
// timeit timer(std::cout, (sstream() << "compiling " << n).str().c_str(), 0.05);
|
||||
expr v = unfold_aux_match(m_env, d.get_value());
|
||||
expr v1 = to_lcnf(m_env, local_ctx(), v);
|
||||
lean_trace(name({"compiler", "lcnf"}), tout() << n << "\n" << v1 << "\n";);
|
||||
lean_cond_assert("compiler", check(d, v1));
|
||||
expr v2 = csimp(m_env, local_ctx(), v1);
|
||||
lean_cond_assert("compiler", check(d, v2));
|
||||
lean_trace(name({"compiler", "simp"}), tout() << "\n" << v2 << "\n";);
|
||||
expr v3 = elim_dead_let(v2);
|
||||
lean_trace(name({"compiler", "elim_dead_let"}), tout() << "\n" << v3 << "\n";);
|
||||
lean_cond_assert("compiler", check(d, v3));
|
||||
expr v4 = cse(m_env, v3);
|
||||
lean_trace(name({"compiler", "cse"}), tout() << "\n" << v4 << "\n";);
|
||||
lean_cond_assert("compiler", check(d, v4));
|
||||
v = to_lcnf(m_env, local_ctx(), v);
|
||||
lean_trace(name({"compiler", "lcnf"}), tout() << n << "\n" << v << "\n";);
|
||||
lean_cond_assert("compiler", check(d, v));
|
||||
v = cce(m_env, local_ctx(), v);
|
||||
lean_trace(name({"compiler", "cce"}), tout() << n << "\n" << v << "\n";);
|
||||
lean_cond_assert("compiler", check(d, v));
|
||||
v = csimp(m_env, local_ctx(), v);
|
||||
lean_cond_assert("compiler", check(d, v));
|
||||
lean_trace(name({"compiler", "simp"}), tout() << "\n" << v << "\n";);
|
||||
v = elim_dead_let(v);
|
||||
lean_trace(name({"compiler", "elim_dead_let"}), tout() << "\n" << v << "\n";);
|
||||
lean_cond_assert("compiler", check(d, v));
|
||||
v = cse(m_env, v);
|
||||
lean_trace(name({"compiler", "cse"}), tout() << "\n" << v << "\n";);
|
||||
lean_cond_assert("compiler", check(d, v));
|
||||
// std::cout << "done compiling " << n << "\n";
|
||||
v4 = max_sharing(v4);
|
||||
lean_trace(name({"compiler", "stage1"}), tout() << n << "\n" << v4 << "\n";);
|
||||
v = max_sharing(v);
|
||||
lean_trace(name({"compiler", "stage1"}), tout() << n << "\n" << v << "\n";);
|
||||
declaration simp_decl = mk_definition(mk_cstage1_name(n), d.get_lparams(), d.get_type(),
|
||||
v4, reducibility_hints::mk_opaque(), true);
|
||||
v, reducibility_hints::mk_opaque(), true);
|
||||
/* IMPORTANT: We do not need to save the auxiliary declaration in the environment.
|
||||
This is just a temporary hack.
|
||||
We should store this information in a different place. In the meantime,
|
||||
|
|
@ -349,6 +352,7 @@ void initialize_preprocess() {
|
|||
register_trace_class("compiler");
|
||||
register_trace_class({"compiler", "input"});
|
||||
register_trace_class({"compiler", "lcnf"});
|
||||
register_trace_class({"compiler", "cce"});
|
||||
register_trace_class({"compiler", "simp"});
|
||||
register_trace_class({"compiler", "stage1"});
|
||||
register_trace_class({"compiler", "expand_aux"});
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#include <string>
|
||||
#include "kernel/type_checker.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "library/attribute_manager.h"
|
||||
|
|
@ -80,4 +81,8 @@ expr mk_lc_unreachable(type_checker::state & s, local_ctx const & lctx, expr con
|
|||
level lvl = sort_level(tc.ensure_type(type));
|
||||
return mk_app(mk_constant(get_lc_unreachable_name(), {lvl}), type);
|
||||
}
|
||||
|
||||
bool is_join_point_name(name const & n) {
|
||||
return !n.is_atomic() && n.is_string() && strncmp(n.get_string().data(), "_join", 5) == 0;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,6 +39,9 @@ inline bool is_lc_cast_app(expr const & e) { return is_app_of(e, get_lc_cast_nam
|
|||
|
||||
expr mk_lc_unreachable(type_checker::state & s, local_ctx const & lctx, expr const & type);
|
||||
|
||||
inline name mk_join_point_name(name const & n) { return name(n, "_join"); }
|
||||
bool is_join_point_name(name const & n);
|
||||
|
||||
/* Create an auxiliary names for a declaration that saves the result of the compilation
|
||||
after step simplification. */
|
||||
inline name mk_cstage1_name(name const & decl_name) {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue