From 85d151a335a14ef48130e7866f92febed9c355fc Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 9 Jul 2019 16:34:20 -0700 Subject: [PATCH] feat(library/compiler): use eta expansion at `eager_lambda_lifting` --- src/library/compiler/compiler.cpp | 2 +- src/library/compiler/eager_lambda_lifting.cpp | 25 +++++++++---- src/library/compiler/eager_lambda_lifting.h | 4 +-- src/library/compiler/specialize.cpp | 34 +----------------- src/library/compiler/util.cpp | 36 +++++++++++++++++++ src/library/compiler/util.h | 4 +++ 6 files changed, 63 insertions(+), 42 deletions(-) diff --git a/src/library/compiler/compiler.cpp b/src/library/compiler/compiler.cpp index ad8f273ab5..06ea50f9eb 100644 --- a/src/library/compiler/compiler.cpp +++ b/src/library/compiler/compiler.cpp @@ -198,7 +198,7 @@ environment compile(environment const & env, options const & opts, names cs) { ds = apply(simp, env, ds); trace_compiler(name({"compiler", "simp"}), ds); environment new_env = env; - std::tie(new_env, ds) = eager_lambda_lifting(new_env, ds); + std::tie(new_env, ds) = eager_lambda_lifting(new_env, ds, cfg); trace_compiler(name({"compiler", "eager_lambda_lifting"}), ds); ds = apply(max_sharing, ds); trace_compiler(name({"compiler", "stage1"}), ds); diff --git a/src/library/compiler/eager_lambda_lifting.cpp b/src/library/compiler/eager_lambda_lifting.cpp index 84e8585c36..59bdfdd56a 100644 --- a/src/library/compiler/eager_lambda_lifting.cpp +++ b/src/library/compiler/eager_lambda_lifting.cpp @@ -13,6 +13,7 @@ Author: Leonardo de Moura #include "library/module.h" #include "library/class.h" #include "library/compiler/util.h" +#include "library/compiler/csimp.h" #include "library/compiler/closed_term_cache.h" namespace lean { @@ -98,6 +99,7 @@ static bool depends_on_fvar(local_ctx const & lctx, buffer const & params, */ class eager_lambda_lifting_fn { type_checker::state m_st; + csimp_cfg m_cfg; local_ctx m_lctx; buffer m_new_decls; name m_base_name; @@ -110,6 +112,10 @@ class eager_lambda_lifting_fn { name_generator & ngen() { return m_st.ngen(); } + expr eta_expand(expr const & e) { + return lcnf_eta_expand(m_st, m_lctx, e); + } + name next_name() { name r = mk_elambda_lifting_name(m_base_name, m_next_idx); m_next_idx++; @@ -183,7 +189,7 @@ class eager_lambda_lifting_fn { } } - expr lift_lambda(expr e) { + expr lift_lambda(expr e, bool apply_simp) { lean_assert(is_lambda(e)); buffer fvars; if (!collect_fvars(e, fvars)) { @@ -216,6 +222,9 @@ class eager_lambda_lifting_fn { expr type = abstract(decl.get_type(), i, new_params.data()); code = ::lean::mk_lambda(decl.get_user_name(), type, code); } + if (apply_simp) { + code = csimp(env(), code, m_cfg); + } expr type = cheap_beta_reduce(type_checker(m_st).infer(code)); name n = next_name(); /* We add the auxiliary declaration `n` as a "meta" axiom to the environment. @@ -298,7 +307,10 @@ class eager_lambda_lifting_fn { expr type = abstract(decl.get_type(), i, fvars.data()); expr val = *decl.get_value(); if (m_terminal_lambdas.contains(n) && !m_nonterminal_lambdas.contains(n)) { - val = lift_lambda(val); + expr new_val = eta_expand(val); + lean_assert(is_lambda(new_val)); + bool apply_simp = new_val != val; + val = lift_lambda(new_val, apply_simp); } r = ::lean::mk_let(decl.get_user_name(), type, abstract(val, i, fvars.data()), r); } @@ -363,6 +375,7 @@ class eager_lambda_lifting_fn { if (is_fvar(arg)) { name x; expr v; std::tie(x, v) = find(arg); + v = eta_expand(v); if (is_lambda(v)) { m_terminal_lambdas.insert(x); } @@ -375,8 +388,8 @@ class eager_lambda_lifting_fn { } public: - eager_lambda_lifting_fn(environment const & env): - m_st(env) {} + eager_lambda_lifting_fn(environment const & env, csimp_cfg const & cfg): + m_st(env), m_cfg(cfg) {} pair operator()(comp_decl const & cdecl) { m_base_name = cdecl.fst(); @@ -387,14 +400,14 @@ public: } }; -pair eager_lambda_lifting(environment env, comp_decls const & ds) { +pair eager_lambda_lifting(environment env, comp_decls const & ds, csimp_cfg const & cfg) { comp_decls r; for (comp_decl const & d : ds) { if (has_inline_attribute(env, d.fst()) || is_instance(env, d.fst())) { r = append(r, comp_decls(d)); } else { comp_decls new_ds; - std::tie(env, new_ds) = eager_lambda_lifting_fn(env)(d); + std::tie(env, new_ds) = eager_lambda_lifting_fn(env, cfg)(d); r = append(r, new_ds); } } diff --git a/src/library/compiler/eager_lambda_lifting.h b/src/library/compiler/eager_lambda_lifting.h index ce3e89960b..04ab99872d 100644 --- a/src/library/compiler/eager_lambda_lifting.h +++ b/src/library/compiler/eager_lambda_lifting.h @@ -6,10 +6,10 @@ Author: Leonardo de Moura */ #pragma once #include "kernel/environment.h" -#include "library/compiler/util.h" +#include "library/compiler/csimp.h" namespace lean { /** \brief Eager version of lambda lifting. See comment at eager_lambda_lifting.cpp. */ -pair eager_lambda_lifting(environment env, comp_decls const & ds); +pair eager_lambda_lifting(environment env, comp_decls const & ds, csimp_cfg const & cfg); /* Return true iff `fn` is the name of an auxiliary function generated by `eager_lambda_lifting`. */ bool is_elambda_lifting_name(name fn); }; diff --git a/src/library/compiler/specialize.cpp b/src/library/compiler/specialize.cpp index 992a0b68f2..911e7c5c7c 100644 --- a/src/library/compiler/specialize.cpp +++ b/src/library/compiler/specialize.cpp @@ -766,39 +766,7 @@ class specialize_fn { } expr eta_expand_specialization(expr e) { - /* Remark: we do not use `type_checker.eta_expand` because it does not preserve LCNF */ - try { - buffer args; - type_checker tc(m_st); - expr e_type = tc.whnf(tc.infer(e)); - local_ctx lctx; - while (is_pi(e_type)) { - expr arg = lctx.mk_local_decl(ngen(), binding_name(e_type), binding_domain(e_type), binding_info(e_type)); - args.push_back(arg); - e_type = type_checker(m_st, lctx).whnf(instantiate(binding_body(e_type), arg)); - } - if (args.empty()) - return e; - buffer fvars; - while (is_let(e)) { - expr type = instantiate_rev(let_type(e), fvars.size(), fvars.data()); - expr val = instantiate_rev(let_value(e), fvars.size(), fvars.data()); - expr fvar = lctx.mk_local_decl(ngen(), let_name(e), type, val); - fvars.push_back(fvar); - e = let_body(e); - } - e = instantiate_rev(e, fvars.size(), fvars.data()); - if (!is_lcnf_atom(e)) { - e = lctx.mk_local_decl(ngen(), "_e", type_checker(m_st, lctx).infer(e), e); - fvars.push_back(e); - } - e = mk_app(e, args); - return lctx.mk_lambda(args, lctx.mk_lambda(fvars, e)); - } catch (exception &) { - /* This can happen since previous compilation steps may have - produced type incorrect terms. */ - return e; - } + return lcnf_eta_expand(m_st, local_ctx(), e); } expr abstract_spec_ctx(spec_ctx const & ctx, expr const & code) { diff --git a/src/library/compiler/util.cpp b/src/library/compiler/util.cpp index 3a7e76dd32..45d0908896 100644 --- a/src/library/compiler/util.cpp +++ b/src/library/compiler/util.cpp @@ -634,6 +634,42 @@ optional is_enum_type(environment const & env, expr const & type) { // ======================================= + +expr lcnf_eta_expand(type_checker::state & st, local_ctx lctx, expr e) { + /* Remark: we do not use `type_checker.eta_expand` because it does not preserve LCNF */ + try { + buffer args; + type_checker tc(st, lctx); + expr e_type = tc.whnf(tc.infer(e)); + while (is_pi(e_type)) { + expr arg = lctx.mk_local_decl(st.ngen(), binding_name(e_type), binding_domain(e_type), binding_info(e_type)); + args.push_back(arg); + e_type = type_checker(st, lctx).whnf(instantiate(binding_body(e_type), arg)); + } + if (args.empty()) + return e; + buffer fvars; + while (is_let(e)) { + expr type = instantiate_rev(let_type(e), fvars.size(), fvars.data()); + expr val = instantiate_rev(let_value(e), fvars.size(), fvars.data()); + expr fvar = lctx.mk_local_decl(st.ngen(), let_name(e), type, val); + fvars.push_back(fvar); + e = let_body(e); + } + e = instantiate_rev(e, fvars.size(), fvars.data()); + if (!is_lcnf_atom(e)) { + e = lctx.mk_local_decl(st.ngen(), "_e", type_checker(st, lctx).infer(e), e); + fvars.push_back(e); + } + e = mk_app(e, args); + return lctx.mk_lambda(args, lctx.mk_lambda(fvars, e)); + } catch (exception &) { + /* This can happen since previous compilation steps may have + produced type incorrect terms. */ + return e; + } +} + void initialize_compiler_util() { g_neutral_expr = new expr(mk_constant("_neutral")); g_unreachable_expr = new expr(mk_constant("_unreachable")); diff --git a/src/library/compiler/util.h b/src/library/compiler/util.h index 35197a32a5..0f5d3e5e59 100644 --- a/src/library/compiler/util.h +++ b/src/library/compiler/util.h @@ -173,6 +173,10 @@ optional mk_enf_fix_core(unsigned n); bool lcnf_check_let_decls(environment const & env, comp_decl const & d); bool lcnf_check_let_decls(environment const & env, comp_decls const & ds); +// ======================================= +/* Similar to `type_checker::eta_expand`, but preserves LCNF */ +expr lcnf_eta_expand(type_checker::state & st, local_ctx lctx, expr e); + // ======================================= // UInt and USize helper functions