From ff2e28e5574fc037b277cae81806d43d9c4527ec Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 20 Sep 2018 21:30:08 -0700 Subject: [PATCH] feat(library/compiler): add `cce`: common case elimination --- src/library/compiler/cse.cpp | 247 ++++++++++++++++++++++++++++ src/library/compiler/cse.h | 3 + src/library/compiler/csimp.cpp | 7 +- src/library/compiler/preprocess.cpp | 34 ++-- src/library/compiler/util.cpp | 5 + src/library/compiler/util.h | 3 + 6 files changed, 282 insertions(+), 17 deletions(-) diff --git a/src/library/compiler/cse.cpp b/src/library/compiler/cse.cpp index b603ec616e..5c9f8b3cc6 100644 --- a/src/library/compiler/cse.cpp +++ b/src/library/compiler/cse.cpp @@ -6,11 +6,15 @@ Author: Leonardo de Moura */ #include #include +#include "runtime/flet.h" #include "util/name_generator.h" #include "kernel/environment.h" #include "kernel/instantiate.h" #include "kernel/abstract.h" +#include "kernel/for_each_fn.h" +#include "kernel/replace_fn.h" #include "kernel/expr_maps.h" +#include "kernel/expr_sets.h" #include "library/compiler/util.h" namespace lean { @@ -131,6 +135,249 @@ expr cse(environment const & env, expr const & e) { return cse_fn(env)(e); } +/* Common case elimination. + + This transformation creates join-points for identical minor premises. + This is important in code such as + ``` + def get_fn : expr -> tactic expr + | (expr.app f _) := pure f + | _ := throw "expr is not an application" + ``` + The "else"-branch is duplicated by the equation compiler for each constructor different from `expr.app`. */ +class cce_fn { + type_checker::state m_st; + local_ctx m_lctx; + buffer m_fvars; + expr_map m_cce_candidates; + buffer m_cce_targets; + name m_j; + unsigned m_next_idx{1}; +public: + environment & env() { return m_st.env(); } + + name_generator & ngen() { return m_st.ngen(); } + + unsigned get_fvar_idx(expr const & x) { + return m_lctx.get_local_decl(x).get_idx(); + } + + unsigned get_max_fvar_idx(expr const & e) { + if (!has_fvar(e)) + return 0; + unsigned r = 0; + for_each(e, [&](expr const & x, unsigned) { + if (!has_fvar(x)) return false; + if (is_fvar(x)) { + unsigned x_idx = get_fvar_idx(x); + if (x_idx > r) + r = x_idx; + } + return true; + }); + return r; + } + + expr replace_target(expr const & e, expr const & target, expr const & jmp) { + return replace(e, [&](expr const & t, unsigned) { + if (target == t) { + return some_expr(jmp); + } + return none_expr(); + }); + } + + expr mk_let_lambda(unsigned old_fvars_size, expr body, bool is_let) { + lean_assert(m_fvars.size() >= old_fvars_size); + if (m_fvars.size() == old_fvars_size) + return body; + unsigned first_var_idx; + if (old_fvars_size == 0) + first_var_idx = 0; + else + first_var_idx = get_fvar_idx(m_fvars[old_fvars_size]); + unsigned j = 0; + buffer> target_jmp_pairs; + name_set new_fvar_names; + for (unsigned i = 0; i < m_cce_targets.size(); i++) { + expr const & target = m_cce_targets[i]; + unsigned max_idx = get_max_fvar_idx(target); + if (max_idx >= first_var_idx) { + expr target_type = cheap_beta_reduce(type_checker(m_st, m_lctx).infer(target)); + expr unit = mk_unit(mk_level_one()); + expr unit_mk = mk_unit_mk(mk_level_one()); + expr new_val = ::lean::mk_lambda("u", unit, target); + expr new_type = ::lean::mk_arrow(unit, target_type); + expr new_fvar = m_lctx.mk_local_decl(ngen(), mk_join_point_name(m_j.append_after(m_next_idx)), new_type, new_val); + new_fvar_names.insert(fvar_name(new_fvar)); + expr jmp = ::lean::mk_let("_j", target_type, mk_app(new_fvar, unit_mk), mk_bvar(0)); + if (is_let) { + /* We must insert new_fvar after fvar with idx == max_idx */ + m_next_idx++; + unsigned k = old_fvars_size; + for (; k < m_fvars.size(); k++) { + expr const & fvar = m_fvars[k]; + if (get_fvar_idx(fvar) > max_idx) { + m_fvars.insert(k, new_fvar); + /* We need to save the pairs to replace the `target` on let-declarations that occurr after k */ + target_jmp_pairs.emplace_back(target, jmp); + break; + } + } + if (k == m_fvars.size()) { + m_fvars.push_back(new_fvar); + } + } else { + lean_assert(!is_let); + /* For lambda we add new free variable after lambda vars */ + m_fvars.push_back(new_fvar); + } + body = replace_target(body, target, jmp); + } else { + m_cce_targets[j] = target; + j++; + } + } + m_cce_targets.shrink(j); + if (is_let && !target_jmp_pairs.empty()) { + expr r = abstract(body, m_fvars.size() - old_fvars_size, m_fvars.data() + old_fvars_size); + unsigned i = m_fvars.size(); + while (i > old_fvars_size) { + --i; + expr fvar = m_fvars[i]; + local_decl decl = m_lctx.get_local_decl(fvar); + expr type = abstract(decl.get_type(), i - old_fvars_size, m_fvars.data() + old_fvars_size); + lean_assert(decl.get_value()); + expr val = *decl.get_value(); + if ((!new_fvar_names.contains(fvar_name(fvar))) && + (is_lambda(val) || is_cases_on_app(env(), val))) { + for (pair const & p : target_jmp_pairs) { + val = replace_target(val, p.first, p.second); + } + } + val = abstract(val, i - old_fvars_size, m_fvars.data() + old_fvars_size); + r = ::lean::mk_let(decl.get_user_name(), type, val, r); + } + m_fvars.shrink(old_fvars_size); + return r; + } else { + expr r = m_lctx.mk_lambda(m_fvars.size() - old_fvars_size, m_fvars.data() + old_fvars_size, body); + m_fvars.shrink(old_fvars_size); + return r; + } + } + + expr mk_let(unsigned old_fvars_size, expr const & body) { return mk_let_lambda(old_fvars_size, body, true); } + + expr mk_lambda(unsigned old_fvars_size, expr const & body) { return mk_let_lambda(old_fvars_size, body, false); } + + expr visit_let(expr e) { + buffer let_fvars; + while (is_let(e)) { + expr new_type = instantiate_rev(let_type(e), let_fvars.size(), let_fvars.data()); + expr new_val = visit_let_value(instantiate_rev(let_value(e), let_fvars.size(), let_fvars.data())); + expr new_fvar = m_lctx.mk_local_decl(ngen(), let_name(e), new_type, new_val); + let_fvars.push_back(new_fvar); + m_fvars.push_back(new_fvar); + e = let_body(e); + } + return instantiate_rev(e, let_fvars.size(), let_fvars.data()); + } + + expr visit_lambda(expr e) { + lean_assert(is_lambda(e)); + flet save_lctx(m_lctx, m_lctx); + unsigned fvars_sz1 = m_fvars.size(); + while (is_lambda(e)) { + /* Types are ignored in compilation steps. So, we do not invoke visit for d. */ + expr new_d = instantiate_rev(binding_domain(e), m_fvars.size() - fvars_sz1, m_fvars.data() + fvars_sz1); + expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(e), new_d, binding_info(e)); + m_fvars.push_back(new_fvar); + e = binding_body(e); + } + unsigned fvars_sz2 = m_fvars.size(); + expr new_body = visit(instantiate_rev(e, m_fvars.size() - fvars_sz1, m_fvars.data() + fvars_sz1)); + new_body = mk_let(fvars_sz2, new_body); + return mk_lambda(fvars_sz1, new_body); + } + + void add_candidate(expr const & e) { + auto it = m_cce_candidates.find(e); + if (it == m_cce_candidates.end()) { + m_cce_candidates.insert(mk_pair(e, true)); + } else if (it->second) { + m_cce_targets.push_back(e); + it->second = false; + } + } + + expr visit_app(expr const & e) { + if (!is_cases_on_app(env(), e)) return e; + buffer args; + expr const & c = get_app_args(e, args); + lean_assert(is_constant(c)); + inductive_val I_val = env().get(const_name(c).get_prefix()).to_inductive_val(); + unsigned motive_idx = I_val.get_nparams(); + unsigned first_index = motive_idx + 1; + unsigned nindices = I_val.get_nindices(); + unsigned major_idx = first_index + nindices; + unsigned first_minor_idx = major_idx + 1; + unsigned nminors = length(I_val.get_cnstrs()); + /* visit minor premises */ + for (unsigned i = 0; i < nminors; i++) { + unsigned minor_idx = first_minor_idx + i; + expr minor = args[minor_idx]; + flet save_lctx(m_lctx, m_lctx); + unsigned fvars_sz1 = m_fvars.size(); + while (is_lambda(minor)) { + expr new_d = instantiate_rev(binding_domain(minor), m_fvars.size() - fvars_sz1, m_fvars.data() + fvars_sz1); + expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(minor), new_d, binding_info(minor)); + m_fvars.push_back(new_fvar); + minor = binding_body(minor); + } + bool is_cce_target = !has_loose_bvars(minor); + unsigned fvars_sz2 = m_fvars.size(); + expr new_minor = visit(instantiate_rev(minor, m_fvars.size() - fvars_sz1, m_fvars.data() + fvars_sz1)); + new_minor = mk_let(fvars_sz2, new_minor); + if (is_cce_target && !is_lcnf_atom(new_minor)) + add_candidate(new_minor); + new_minor = mk_lambda(fvars_sz1, new_minor); + args[minor_idx] = new_minor; + } + return mk_app(c, args); + } + + expr visit_let_value(expr const & e) { + switch (e.kind()) { + case expr_kind::Lambda: return visit_lambda(e); + case expr_kind::App: return visit_app(e); + default: return e; + } + } + + expr visit(expr const & e) { + switch (e.kind()) { + case expr_kind::Lambda: return visit_lambda(e); + case expr_kind::Let: return visit_let(e); + default: return e; + } + } + +public: + cce_fn(environment const & env, local_ctx const & lctx): + m_st(env), m_lctx(lctx), m_j("_j") { + } + + expr operator()(expr const & e) { + expr r = visit(e); + return mk_let(0, r); + } +}; + +expr cce(environment const & env, local_ctx const & lctx, expr const & e) { + return cce_fn(env, lctx)(e); +} + void initialize_cse() { g_cse_fresh = new name("_cse_fresh"); register_name_generator_prefix(*g_cse_fresh); diff --git a/src/library/compiler/cse.h b/src/library/compiler/cse.h index 4250caf972..f3491df868 100644 --- a/src/library/compiler/cse.h +++ b/src/library/compiler/cse.h @@ -7,7 +7,10 @@ Author: Leonardo de Moura #pragma once #include "kernel/environment.h" namespace lean { +/* Common subexpression elimination */ expr cse(environment const & env, expr const & e); +/* Common case elimination */ +expr cce(environment const & env, local_ctx const & lctx, expr const & e); void initialize_cse(); void finalize_cse(); } diff --git a/src/library/compiler/csimp.cpp b/src/library/compiler/csimp.cpp index f9cc0c2516..615f3bcec6 100644 --- a/src/library/compiler/csimp.cpp +++ b/src/library/compiler/csimp.cpp @@ -32,7 +32,8 @@ class csimp_fn { if (is_fvar(e)) { if (optional decl = m_lctx.find_local_decl(e)) { if (optional v = decl->get_value()) - return find(*v, skip_mdata); + if (!is_join_point_name(decl->get_user_name())) + return find(*v, skip_mdata); } } else if (is_mdata(e) && skip_mdata) { return find(mdata_expr(e), true); @@ -131,7 +132,9 @@ class csimp_fn { if (is_lcnf_atom(new_val)) { let_fvars.push_back(new_val); } else { - name n = is_internal_name(let_name(e)) ? next_name() : let_name(e); + name n = let_name(e); + if (is_internal_name(n) && !is_join_point_name(n)) + n = next_name(); expr new_fvar = m_lctx.mk_local_decl(ngen(), n, new_type, new_val); let_fvars.push_back(new_fvar); m_fvars.push_back(new_fvar); diff --git a/src/library/compiler/preprocess.cpp b/src/library/compiler/preprocess.cpp index 34a1f64787..d42f487c28 100644 --- a/src/library/compiler/preprocess.cpp +++ b/src/library/compiler/preprocess.cpp @@ -238,23 +238,26 @@ class preprocess_fn { name n = get_real_name(d.get_name()); // timeit timer(std::cout, (sstream() << "compiling " << n).str().c_str(), 0.05); expr v = unfold_aux_match(m_env, d.get_value()); - expr v1 = to_lcnf(m_env, local_ctx(), v); - lean_trace(name({"compiler", "lcnf"}), tout() << n << "\n" << v1 << "\n";); - lean_cond_assert("compiler", check(d, v1)); - expr v2 = csimp(m_env, local_ctx(), v1); - lean_cond_assert("compiler", check(d, v2)); - lean_trace(name({"compiler", "simp"}), tout() << "\n" << v2 << "\n";); - expr v3 = elim_dead_let(v2); - lean_trace(name({"compiler", "elim_dead_let"}), tout() << "\n" << v3 << "\n";); - lean_cond_assert("compiler", check(d, v3)); - expr v4 = cse(m_env, v3); - lean_trace(name({"compiler", "cse"}), tout() << "\n" << v4 << "\n";); - lean_cond_assert("compiler", check(d, v4)); + v = to_lcnf(m_env, local_ctx(), v); + lean_trace(name({"compiler", "lcnf"}), tout() << n << "\n" << v << "\n";); + lean_cond_assert("compiler", check(d, v)); + v = cce(m_env, local_ctx(), v); + lean_trace(name({"compiler", "cce"}), tout() << n << "\n" << v << "\n";); + lean_cond_assert("compiler", check(d, v)); + v = csimp(m_env, local_ctx(), v); + lean_cond_assert("compiler", check(d, v)); + lean_trace(name({"compiler", "simp"}), tout() << "\n" << v << "\n";); + v = elim_dead_let(v); + lean_trace(name({"compiler", "elim_dead_let"}), tout() << "\n" << v << "\n";); + lean_cond_assert("compiler", check(d, v)); + v = cse(m_env, v); + lean_trace(name({"compiler", "cse"}), tout() << "\n" << v << "\n";); + lean_cond_assert("compiler", check(d, v)); // std::cout << "done compiling " << n << "\n"; - v4 = max_sharing(v4); - lean_trace(name({"compiler", "stage1"}), tout() << n << "\n" << v4 << "\n";); + v = max_sharing(v); + lean_trace(name({"compiler", "stage1"}), tout() << n << "\n" << v << "\n";); declaration simp_decl = mk_definition(mk_cstage1_name(n), d.get_lparams(), d.get_type(), - v4, reducibility_hints::mk_opaque(), true); + v, reducibility_hints::mk_opaque(), true); /* IMPORTANT: We do not need to save the auxiliary declaration in the environment. This is just a temporary hack. We should store this information in a different place. In the meantime, @@ -349,6 +352,7 @@ void initialize_preprocess() { register_trace_class("compiler"); register_trace_class({"compiler", "input"}); register_trace_class({"compiler", "lcnf"}); + register_trace_class({"compiler", "cce"}); register_trace_class({"compiler", "simp"}); register_trace_class({"compiler", "stage1"}); register_trace_class({"compiler", "expand_aux"}); diff --git a/src/library/compiler/util.cpp b/src/library/compiler/util.cpp index 22b9d2e57d..75b32fdcdf 100644 --- a/src/library/compiler/util.cpp +++ b/src/library/compiler/util.cpp @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ +#include #include "kernel/type_checker.h" #include "kernel/instantiate.h" #include "library/attribute_manager.h" @@ -80,4 +81,8 @@ expr mk_lc_unreachable(type_checker::state & s, local_ctx const & lctx, expr con level lvl = sort_level(tc.ensure_type(type)); return mk_app(mk_constant(get_lc_unreachable_name(), {lvl}), type); } + +bool is_join_point_name(name const & n) { + return !n.is_atomic() && n.is_string() && strncmp(n.get_string().data(), "_join", 5) == 0; +} } diff --git a/src/library/compiler/util.h b/src/library/compiler/util.h index 8c16d914a0..a089744493 100644 --- a/src/library/compiler/util.h +++ b/src/library/compiler/util.h @@ -39,6 +39,9 @@ inline bool is_lc_cast_app(expr const & e) { return is_app_of(e, get_lc_cast_nam expr mk_lc_unreachable(type_checker::state & s, local_ctx const & lctx, expr const & type); +inline name mk_join_point_name(name const & n) { return name(n, "_join"); } +bool is_join_point_name(name const & n); + /* Create an auxiliary names for a declaration that saves the result of the compilation after step simplification. */ inline name mk_cstage1_name(name const & decl_name) {