feat(library/compiler): use eta expansion at eager_lambda_lifting
This commit is contained in:
parent
bda0277468
commit
85d151a335
6 changed files with 63 additions and 42 deletions
|
|
@ -198,7 +198,7 @@ environment compile(environment const & env, options const & opts, names cs) {
|
|||
ds = apply(simp, env, ds);
|
||||
trace_compiler(name({"compiler", "simp"}), ds);
|
||||
environment new_env = env;
|
||||
std::tie(new_env, ds) = eager_lambda_lifting(new_env, ds);
|
||||
std::tie(new_env, ds) = eager_lambda_lifting(new_env, ds, cfg);
|
||||
trace_compiler(name({"compiler", "eager_lambda_lifting"}), ds);
|
||||
ds = apply(max_sharing, ds);
|
||||
trace_compiler(name({"compiler", "stage1"}), ds);
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ Author: Leonardo de Moura
|
|||
#include "library/module.h"
|
||||
#include "library/class.h"
|
||||
#include "library/compiler/util.h"
|
||||
#include "library/compiler/csimp.h"
|
||||
#include "library/compiler/closed_term_cache.h"
|
||||
|
||||
namespace lean {
|
||||
|
|
@ -98,6 +99,7 @@ static bool depends_on_fvar(local_ctx const & lctx, buffer<expr> const & params,
|
|||
*/
|
||||
class eager_lambda_lifting_fn {
|
||||
type_checker::state m_st;
|
||||
csimp_cfg m_cfg;
|
||||
local_ctx m_lctx;
|
||||
buffer<comp_decl> m_new_decls;
|
||||
name m_base_name;
|
||||
|
|
@ -110,6 +112,10 @@ class eager_lambda_lifting_fn {
|
|||
|
||||
name_generator & ngen() { return m_st.ngen(); }
|
||||
|
||||
expr eta_expand(expr const & e) {
|
||||
return lcnf_eta_expand(m_st, m_lctx, e);
|
||||
}
|
||||
|
||||
name next_name() {
|
||||
name r = mk_elambda_lifting_name(m_base_name, m_next_idx);
|
||||
m_next_idx++;
|
||||
|
|
@ -183,7 +189,7 @@ class eager_lambda_lifting_fn {
|
|||
}
|
||||
}
|
||||
|
||||
expr lift_lambda(expr e) {
|
||||
expr lift_lambda(expr e, bool apply_simp) {
|
||||
lean_assert(is_lambda(e));
|
||||
buffer<expr> fvars;
|
||||
if (!collect_fvars(e, fvars)) {
|
||||
|
|
@ -216,6 +222,9 @@ class eager_lambda_lifting_fn {
|
|||
expr type = abstract(decl.get_type(), i, new_params.data());
|
||||
code = ::lean::mk_lambda(decl.get_user_name(), type, code);
|
||||
}
|
||||
if (apply_simp) {
|
||||
code = csimp(env(), code, m_cfg);
|
||||
}
|
||||
expr type = cheap_beta_reduce(type_checker(m_st).infer(code));
|
||||
name n = next_name();
|
||||
/* We add the auxiliary declaration `n` as a "meta" axiom to the environment.
|
||||
|
|
@ -298,7 +307,10 @@ class eager_lambda_lifting_fn {
|
|||
expr type = abstract(decl.get_type(), i, fvars.data());
|
||||
expr val = *decl.get_value();
|
||||
if (m_terminal_lambdas.contains(n) && !m_nonterminal_lambdas.contains(n)) {
|
||||
val = lift_lambda(val);
|
||||
expr new_val = eta_expand(val);
|
||||
lean_assert(is_lambda(new_val));
|
||||
bool apply_simp = new_val != val;
|
||||
val = lift_lambda(new_val, apply_simp);
|
||||
}
|
||||
r = ::lean::mk_let(decl.get_user_name(), type, abstract(val, i, fvars.data()), r);
|
||||
}
|
||||
|
|
@ -363,6 +375,7 @@ class eager_lambda_lifting_fn {
|
|||
if (is_fvar(arg)) {
|
||||
name x; expr v;
|
||||
std::tie(x, v) = find(arg);
|
||||
v = eta_expand(v);
|
||||
if (is_lambda(v)) {
|
||||
m_terminal_lambdas.insert(x);
|
||||
}
|
||||
|
|
@ -375,8 +388,8 @@ class eager_lambda_lifting_fn {
|
|||
}
|
||||
|
||||
public:
|
||||
eager_lambda_lifting_fn(environment const & env):
|
||||
m_st(env) {}
|
||||
eager_lambda_lifting_fn(environment const & env, csimp_cfg const & cfg):
|
||||
m_st(env), m_cfg(cfg) {}
|
||||
|
||||
pair<environment, comp_decls> operator()(comp_decl const & cdecl) {
|
||||
m_base_name = cdecl.fst();
|
||||
|
|
@ -387,14 +400,14 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
pair<environment, comp_decls> eager_lambda_lifting(environment env, comp_decls const & ds) {
|
||||
pair<environment, comp_decls> eager_lambda_lifting(environment env, comp_decls const & ds, csimp_cfg const & cfg) {
|
||||
comp_decls r;
|
||||
for (comp_decl const & d : ds) {
|
||||
if (has_inline_attribute(env, d.fst()) || is_instance(env, d.fst())) {
|
||||
r = append(r, comp_decls(d));
|
||||
} else {
|
||||
comp_decls new_ds;
|
||||
std::tie(env, new_ds) = eager_lambda_lifting_fn(env)(d);
|
||||
std::tie(env, new_ds) = eager_lambda_lifting_fn(env, cfg)(d);
|
||||
r = append(r, new_ds);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,10 +6,10 @@ Author: Leonardo de Moura
|
|||
*/
|
||||
#pragma once
|
||||
#include "kernel/environment.h"
|
||||
#include "library/compiler/util.h"
|
||||
#include "library/compiler/csimp.h"
|
||||
namespace lean {
|
||||
/** \brief Eager version of lambda lifting. See comment at eager_lambda_lifting.cpp. */
|
||||
pair<environment, comp_decls> eager_lambda_lifting(environment env, comp_decls const & ds);
|
||||
pair<environment, comp_decls> eager_lambda_lifting(environment env, comp_decls const & ds, csimp_cfg const & cfg);
|
||||
/* Return true iff `fn` is the name of an auxiliary function generated by `eager_lambda_lifting`. */
|
||||
bool is_elambda_lifting_name(name fn);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -766,39 +766,7 @@ class specialize_fn {
|
|||
}
|
||||
|
||||
expr eta_expand_specialization(expr e) {
|
||||
/* Remark: we do not use `type_checker.eta_expand` because it does not preserve LCNF */
|
||||
try {
|
||||
buffer<expr> args;
|
||||
type_checker tc(m_st);
|
||||
expr e_type = tc.whnf(tc.infer(e));
|
||||
local_ctx lctx;
|
||||
while (is_pi(e_type)) {
|
||||
expr arg = lctx.mk_local_decl(ngen(), binding_name(e_type), binding_domain(e_type), binding_info(e_type));
|
||||
args.push_back(arg);
|
||||
e_type = type_checker(m_st, lctx).whnf(instantiate(binding_body(e_type), arg));
|
||||
}
|
||||
if (args.empty())
|
||||
return e;
|
||||
buffer<expr> fvars;
|
||||
while (is_let(e)) {
|
||||
expr type = instantiate_rev(let_type(e), fvars.size(), fvars.data());
|
||||
expr val = instantiate_rev(let_value(e), fvars.size(), fvars.data());
|
||||
expr fvar = lctx.mk_local_decl(ngen(), let_name(e), type, val);
|
||||
fvars.push_back(fvar);
|
||||
e = let_body(e);
|
||||
}
|
||||
e = instantiate_rev(e, fvars.size(), fvars.data());
|
||||
if (!is_lcnf_atom(e)) {
|
||||
e = lctx.mk_local_decl(ngen(), "_e", type_checker(m_st, lctx).infer(e), e);
|
||||
fvars.push_back(e);
|
||||
}
|
||||
e = mk_app(e, args);
|
||||
return lctx.mk_lambda(args, lctx.mk_lambda(fvars, e));
|
||||
} catch (exception &) {
|
||||
/* This can happen since previous compilation steps may have
|
||||
produced type incorrect terms. */
|
||||
return e;
|
||||
}
|
||||
return lcnf_eta_expand(m_st, local_ctx(), e);
|
||||
}
|
||||
|
||||
expr abstract_spec_ctx(spec_ctx const & ctx, expr const & code) {
|
||||
|
|
|
|||
|
|
@ -634,6 +634,42 @@ optional<unsigned> is_enum_type(environment const & env, expr const & type) {
|
|||
|
||||
// =======================================
|
||||
|
||||
|
||||
expr lcnf_eta_expand(type_checker::state & st, local_ctx lctx, expr e) {
|
||||
/* Remark: we do not use `type_checker.eta_expand` because it does not preserve LCNF */
|
||||
try {
|
||||
buffer<expr> args;
|
||||
type_checker tc(st, lctx);
|
||||
expr e_type = tc.whnf(tc.infer(e));
|
||||
while (is_pi(e_type)) {
|
||||
expr arg = lctx.mk_local_decl(st.ngen(), binding_name(e_type), binding_domain(e_type), binding_info(e_type));
|
||||
args.push_back(arg);
|
||||
e_type = type_checker(st, lctx).whnf(instantiate(binding_body(e_type), arg));
|
||||
}
|
||||
if (args.empty())
|
||||
return e;
|
||||
buffer<expr> fvars;
|
||||
while (is_let(e)) {
|
||||
expr type = instantiate_rev(let_type(e), fvars.size(), fvars.data());
|
||||
expr val = instantiate_rev(let_value(e), fvars.size(), fvars.data());
|
||||
expr fvar = lctx.mk_local_decl(st.ngen(), let_name(e), type, val);
|
||||
fvars.push_back(fvar);
|
||||
e = let_body(e);
|
||||
}
|
||||
e = instantiate_rev(e, fvars.size(), fvars.data());
|
||||
if (!is_lcnf_atom(e)) {
|
||||
e = lctx.mk_local_decl(st.ngen(), "_e", type_checker(st, lctx).infer(e), e);
|
||||
fvars.push_back(e);
|
||||
}
|
||||
e = mk_app(e, args);
|
||||
return lctx.mk_lambda(args, lctx.mk_lambda(fvars, e));
|
||||
} catch (exception &) {
|
||||
/* This can happen since previous compilation steps may have
|
||||
produced type incorrect terms. */
|
||||
return e;
|
||||
}
|
||||
}
|
||||
|
||||
void initialize_compiler_util() {
|
||||
g_neutral_expr = new expr(mk_constant("_neutral"));
|
||||
g_unreachable_expr = new expr(mk_constant("_unreachable"));
|
||||
|
|
|
|||
|
|
@ -173,6 +173,10 @@ optional<expr> mk_enf_fix_core(unsigned n);
|
|||
bool lcnf_check_let_decls(environment const & env, comp_decl const & d);
|
||||
bool lcnf_check_let_decls(environment const & env, comp_decls const & ds);
|
||||
|
||||
// =======================================
|
||||
/* Similar to `type_checker::eta_expand`, but preserves LCNF */
|
||||
expr lcnf_eta_expand(type_checker::state & st, local_ctx lctx, expr e);
|
||||
|
||||
// =======================================
|
||||
// UInt and USize helper functions
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue