feat(library/compiler): use eta expansion at eager_lambda_lifting

This commit is contained in:
Leonardo de Moura 2019-07-09 16:34:20 -07:00
parent bda0277468
commit 85d151a335
6 changed files with 63 additions and 42 deletions

View file

@ -198,7 +198,7 @@ environment compile(environment const & env, options const & opts, names cs) {
ds = apply(simp, env, ds);
trace_compiler(name({"compiler", "simp"}), ds);
environment new_env = env;
std::tie(new_env, ds) = eager_lambda_lifting(new_env, ds);
std::tie(new_env, ds) = eager_lambda_lifting(new_env, ds, cfg);
trace_compiler(name({"compiler", "eager_lambda_lifting"}), ds);
ds = apply(max_sharing, ds);
trace_compiler(name({"compiler", "stage1"}), ds);

View file

@ -13,6 +13,7 @@ Author: Leonardo de Moura
#include "library/module.h"
#include "library/class.h"
#include "library/compiler/util.h"
#include "library/compiler/csimp.h"
#include "library/compiler/closed_term_cache.h"
namespace lean {
@ -98,6 +99,7 @@ static bool depends_on_fvar(local_ctx const & lctx, buffer<expr> const & params,
*/
class eager_lambda_lifting_fn {
type_checker::state m_st;
csimp_cfg m_cfg;
local_ctx m_lctx;
buffer<comp_decl> m_new_decls;
name m_base_name;
@ -110,6 +112,10 @@ class eager_lambda_lifting_fn {
name_generator & ngen() { return m_st.ngen(); }
expr eta_expand(expr const & e) {
return lcnf_eta_expand(m_st, m_lctx, e);
}
name next_name() {
name r = mk_elambda_lifting_name(m_base_name, m_next_idx);
m_next_idx++;
@ -183,7 +189,7 @@ class eager_lambda_lifting_fn {
}
}
expr lift_lambda(expr e) {
expr lift_lambda(expr e, bool apply_simp) {
lean_assert(is_lambda(e));
buffer<expr> fvars;
if (!collect_fvars(e, fvars)) {
@ -216,6 +222,9 @@ class eager_lambda_lifting_fn {
expr type = abstract(decl.get_type(), i, new_params.data());
code = ::lean::mk_lambda(decl.get_user_name(), type, code);
}
if (apply_simp) {
code = csimp(env(), code, m_cfg);
}
expr type = cheap_beta_reduce(type_checker(m_st).infer(code));
name n = next_name();
/* We add the auxiliary declaration `n` as a "meta" axiom to the environment.
@ -298,7 +307,10 @@ class eager_lambda_lifting_fn {
expr type = abstract(decl.get_type(), i, fvars.data());
expr val = *decl.get_value();
if (m_terminal_lambdas.contains(n) && !m_nonterminal_lambdas.contains(n)) {
val = lift_lambda(val);
expr new_val = eta_expand(val);
lean_assert(is_lambda(new_val));
bool apply_simp = new_val != val;
val = lift_lambda(new_val, apply_simp);
}
r = ::lean::mk_let(decl.get_user_name(), type, abstract(val, i, fvars.data()), r);
}
@ -363,6 +375,7 @@ class eager_lambda_lifting_fn {
if (is_fvar(arg)) {
name x; expr v;
std::tie(x, v) = find(arg);
v = eta_expand(v);
if (is_lambda(v)) {
m_terminal_lambdas.insert(x);
}
@ -375,8 +388,8 @@ class eager_lambda_lifting_fn {
}
public:
eager_lambda_lifting_fn(environment const & env):
m_st(env) {}
eager_lambda_lifting_fn(environment const & env, csimp_cfg const & cfg):
m_st(env), m_cfg(cfg) {}
pair<environment, comp_decls> operator()(comp_decl const & cdecl) {
m_base_name = cdecl.fst();
@ -387,14 +400,14 @@ public:
}
};
pair<environment, comp_decls> eager_lambda_lifting(environment env, comp_decls const & ds) {
pair<environment, comp_decls> eager_lambda_lifting(environment env, comp_decls const & ds, csimp_cfg const & cfg) {
comp_decls r;
for (comp_decl const & d : ds) {
if (has_inline_attribute(env, d.fst()) || is_instance(env, d.fst())) {
r = append(r, comp_decls(d));
} else {
comp_decls new_ds;
std::tie(env, new_ds) = eager_lambda_lifting_fn(env)(d);
std::tie(env, new_ds) = eager_lambda_lifting_fn(env, cfg)(d);
r = append(r, new_ds);
}
}

View file

@ -6,10 +6,10 @@ Author: Leonardo de Moura
*/
#pragma once
#include "kernel/environment.h"
#include "library/compiler/util.h"
#include "library/compiler/csimp.h"
namespace lean {
/** \brief Eager version of lambda lifting. See comment at eager_lambda_lifting.cpp. */
pair<environment, comp_decls> eager_lambda_lifting(environment env, comp_decls const & ds);
pair<environment, comp_decls> eager_lambda_lifting(environment env, comp_decls const & ds, csimp_cfg const & cfg);
/* Return true iff `fn` is the name of an auxiliary function generated by `eager_lambda_lifting`. */
bool is_elambda_lifting_name(name fn);
};

View file

@ -766,39 +766,7 @@ class specialize_fn {
}
expr eta_expand_specialization(expr e) {
/* Remark: we do not use `type_checker.eta_expand` because it does not preserve LCNF */
try {
buffer<expr> args;
type_checker tc(m_st);
expr e_type = tc.whnf(tc.infer(e));
local_ctx lctx;
while (is_pi(e_type)) {
expr arg = lctx.mk_local_decl(ngen(), binding_name(e_type), binding_domain(e_type), binding_info(e_type));
args.push_back(arg);
e_type = type_checker(m_st, lctx).whnf(instantiate(binding_body(e_type), arg));
}
if (args.empty())
return e;
buffer<expr> fvars;
while (is_let(e)) {
expr type = instantiate_rev(let_type(e), fvars.size(), fvars.data());
expr val = instantiate_rev(let_value(e), fvars.size(), fvars.data());
expr fvar = lctx.mk_local_decl(ngen(), let_name(e), type, val);
fvars.push_back(fvar);
e = let_body(e);
}
e = instantiate_rev(e, fvars.size(), fvars.data());
if (!is_lcnf_atom(e)) {
e = lctx.mk_local_decl(ngen(), "_e", type_checker(m_st, lctx).infer(e), e);
fvars.push_back(e);
}
e = mk_app(e, args);
return lctx.mk_lambda(args, lctx.mk_lambda(fvars, e));
} catch (exception &) {
/* This can happen since previous compilation steps may have
produced type incorrect terms. */
return e;
}
return lcnf_eta_expand(m_st, local_ctx(), e);
}
expr abstract_spec_ctx(spec_ctx const & ctx, expr const & code) {

View file

@ -634,6 +634,42 @@ optional<unsigned> is_enum_type(environment const & env, expr const & type) {
// =======================================
expr lcnf_eta_expand(type_checker::state & st, local_ctx lctx, expr e) {
/* Remark: we do not use `type_checker.eta_expand` because it does not preserve LCNF */
try {
buffer<expr> args;
type_checker tc(st, lctx);
expr e_type = tc.whnf(tc.infer(e));
while (is_pi(e_type)) {
expr arg = lctx.mk_local_decl(st.ngen(), binding_name(e_type), binding_domain(e_type), binding_info(e_type));
args.push_back(arg);
e_type = type_checker(st, lctx).whnf(instantiate(binding_body(e_type), arg));
}
if (args.empty())
return e;
buffer<expr> fvars;
while (is_let(e)) {
expr type = instantiate_rev(let_type(e), fvars.size(), fvars.data());
expr val = instantiate_rev(let_value(e), fvars.size(), fvars.data());
expr fvar = lctx.mk_local_decl(st.ngen(), let_name(e), type, val);
fvars.push_back(fvar);
e = let_body(e);
}
e = instantiate_rev(e, fvars.size(), fvars.data());
if (!is_lcnf_atom(e)) {
e = lctx.mk_local_decl(st.ngen(), "_e", type_checker(st, lctx).infer(e), e);
fvars.push_back(e);
}
e = mk_app(e, args);
return lctx.mk_lambda(args, lctx.mk_lambda(fvars, e));
} catch (exception &) {
/* This can happen since previous compilation steps may have
produced type incorrect terms. */
return e;
}
}
void initialize_compiler_util() {
g_neutral_expr = new expr(mk_constant("_neutral"));
g_unreachable_expr = new expr(mk_constant("_unreachable"));

View file

@ -173,6 +173,10 @@ optional<expr> mk_enf_fix_core(unsigned n);
bool lcnf_check_let_decls(environment const & env, comp_decl const & d);
bool lcnf_check_let_decls(environment const & env, comp_decls const & ds);
// =======================================
/* Similar to `type_checker::eta_expand`, but preserves LCNF */
expr lcnf_eta_expand(type_checker::state & st, local_ctx lctx, expr e);
// =======================================
// UInt and USize helper functions