feat(library/compiler/llnf): new reset/reuse insertion procedure

This commit is contained in:
Leonardo de Moura 2019-02-20 16:12:58 -08:00
parent 937b947938
commit 54a89dabb7
2 changed files with 320 additions and 87 deletions

View file

@ -107,7 +107,7 @@ scalar fields.
expr mk_llnf_cnstr(name const & I, unsigned cidx, unsigned num_usizes, unsigned num_bytes) {
return mk_constant(name(name(name(name(I, g_cnstr), cidx), num_usizes), num_bytes));
}
bool is_llnf_cnstr_core(expr const & e, unsigned & cidx, unsigned & num_usizes, unsigned & num_bytes) {
bool is_llnf_cnstr(expr const & e, name & I, unsigned & cidx, unsigned & num_usizes, unsigned & num_bytes) {
if (!is_constant(e)) return false;
name const & n3 = const_name(e);
if (!is_internal_name(n3)) return false;
@ -120,10 +120,9 @@ bool is_llnf_cnstr_core(expr const & e, unsigned & cidx, unsigned & num_usizes,
if (n1.is_atomic() || !n1.is_numeral()) return false;
cidx = n1.get_numeral().get_small_value();
name const & n0 = n1.get_prefix();
return !n0.is_atomic() && n0.is_string() && n0.get_string() == g_cnstr;
}
bool is_llnf_cnstr(expr const & e, unsigned & cidx, unsigned & num_usizes, unsigned & num_bytes) {
return is_llnf_cnstr_core(e, cidx, num_usizes, num_bytes);
if (n0.is_atomic() || !n0.is_string() || n0.get_string() != g_cnstr) return false;
I = n0.get_prefix();
return true;
}
/* The `_reuse.<cidx>.<num_usizes>.<num_bytes>.<updt_cidx>` is similar to `_cnstr.<cidx>.<num_usize>.<num_bytes>`, but it takes an extra argument: a memory cell that may be reused. */
@ -286,6 +285,82 @@ static void get_borrowed_info(environment const & env, name const & n, buffer<bo
borrowed_res = false;
}
static optional<unsigned> is_enum_type(environment const & env, expr const & type) {
expr const & I = get_app_fn(type);
if (!is_constant(I)) return optional<unsigned>();
return is_enum_type(env, const_name(I));
}
static void get_cnstr_info_core(type_checker::state & st, bool unboxed, name const & n, buffer<field_info> & result) {
environment const & env = st.env();
constant_info info = env.get(n);
lean_assert(info.is_constructor());
constructor_val val = info.to_constructor_val();
expr type = info.get_type();
name I_name = val.get_induct();
unsigned nparams = val.get_nparams();
local_ctx lctx;
buffer<expr> telescope;
unsigned next_object = 0;
unsigned next_usize = 0;
unsigned next_offset = 0;
to_telescope(env, lctx, st.ngen(), type, telescope);
lean_assert(telescope.size() >= nparams);
for (unsigned i = nparams; i < telescope.size(); i++) {
expr ftype = lctx.get_type(telescope[i]);
if (is_irrelevant_type(st, lctx, ftype)) {
result.push_back(field_info::mk_irrelevant());
} else if (unboxed) {
type_checker tc(st, lctx);
ftype = tc.whnf(ftype);
if (is_usize_type(ftype)) {
result.push_back(field_info::mk_usize(next_usize));
next_usize++;
} else if (optional<unsigned> sz = is_builtin_scalar(ftype)) {
result.push_back(field_info::mk_scalar(*sz, next_offset, ftype));
next_offset += *sz;
} else if (optional<unsigned> sz = is_enum_type(env, ftype)) {
optional<expr> uint = to_uint_type(*sz);
if (!uint) throw exception("code generation failed, enumeration type is too big");
result.push_back(field_info::mk_scalar(*sz, next_offset, *uint));
next_offset += *sz;
} else {
result.push_back(field_info::mk_object(next_object));
next_object++;
}
} else {
result.push_back(field_info::mk_object(next_object));
next_object++;
}
}
unsigned nobjs = next_object;
unsigned nusizes = next_usize;
if (unboxed) {
/* Remark:
- usize fields are stored after object fields.
- regular scalar fields are stored after object and usize fields */
for (field_info & info : result) {
switch (info.m_kind) {
case field_info::Scalar:
info.m_offset += (nobjs + nusizes) * sizeof(void*);
break;
case field_info::USize:
info.m_offset += nobjs * sizeof(void*);
break;
default:
break;
}
}
}
}
static cnstr_info get_cnstr_info(type_checker::state & st, bool unboxed, name const & n) {
buffer<field_info> finfos;
get_cnstr_info_core(st, unboxed, n, finfos);
unsigned cidx = get_constructor_idx(st.env(), n);
return cnstr_info(cidx, to_list(finfos));
}
class to_llnf_fn {
typedef name_hash_set name_set;
typedef name_hash_map<cnstr_info> cnstr_info_cache;
@ -299,83 +374,13 @@ class to_llnf_fn {
unsigned m_next_idx{1};
unsigned m_next_jp_idx{1};
cnstr_info_cache m_cnstr_info_cache;
enum_cache m_enum_cache;
environment const & env() const { return m_st.env(); }
name_generator & ngen() { return m_st.ngen(); }
optional<unsigned> is_enum_type(expr const & type) {
expr const & I = get_app_fn(type);
if (!is_constant(I)) return optional<unsigned>();
auto it = m_enum_cache.find(const_name(I));
if (it != m_enum_cache.end())
return it->second;
optional<unsigned> r = ::lean::is_enum_type(env(), const_name(I));
m_enum_cache.insert(mk_pair(const_name(I), r));
return r;
}
void get_cnstr_info_core(name const & n, buffer<field_info> & result) {
constant_info info = env().get(n);
lean_assert(info.is_constructor());
constructor_val val = info.to_constructor_val();
expr type = info.get_type();
name I_name = val.get_induct();
unsigned nparams = val.get_nparams();
local_ctx lctx;
buffer<expr> telescope;
unsigned next_object = 0;
unsigned next_usize = 0;
unsigned next_offset = 0;
to_telescope(env(), lctx, ngen(), type, telescope);
lean_assert(telescope.size() >= nparams);
for (unsigned i = nparams; i < telescope.size(); i++) {
expr ftype = lctx.get_type(telescope[i]);
if (is_irrelevant_type(m_st, lctx, ftype)) {
result.push_back(field_info::mk_irrelevant());
} else if (m_unboxed) {
type_checker tc(m_st, lctx);
ftype = tc.whnf(ftype);
if (is_usize_type(ftype)) {
result.push_back(field_info::mk_usize(next_usize));
next_usize++;
} else if (optional<unsigned> sz = is_builtin_scalar(ftype)) {
result.push_back(field_info::mk_scalar(*sz, next_offset, ftype));
next_offset += *sz;
} else if (optional<unsigned> sz = is_enum_type(ftype)) {
optional<expr> uint = to_uint_type(*sz);
if (!uint) throw exception("code generation failed, enumeration type is too big");
result.push_back(field_info::mk_scalar(*sz, next_offset, *uint));
next_offset += *sz;
} else {
result.push_back(field_info::mk_object(next_object));
next_object++;
}
} else {
result.push_back(field_info::mk_object(next_object));
next_object++;
}
}
unsigned nobjs = next_object;
unsigned nusizes = next_usize;
if (m_unboxed) {
/* Remark:
- usize fields are stored after object fields.
- regular scalar fields are stored after object and usize fields */
for (field_info & info : result) {
switch (info.m_kind) {
case field_info::Scalar:
info.m_offset += (nobjs + nusizes) * sizeof(void*);
break;
case field_info::USize:
info.m_offset += nobjs * sizeof(void*);
break;
default:
break;
}
}
}
return ::lean::is_enum_type(env(), type);
}
unsigned get_arity(name const & n) const {
@ -419,10 +424,7 @@ class to_llnf_fn {
auto it = m_cnstr_info_cache.find(n);
if (it != m_cnstr_info_cache.end())
return it->second;
buffer<field_info> finfos;
get_cnstr_info_core(n, finfos);
unsigned cidx = get_constructor_idx(env(), n);
cnstr_info r(cidx, to_list(finfos));
cnstr_info r = ::lean::get_cnstr_info(m_st, m_unboxed, n);
m_cnstr_info_cache.insert(mk_pair(n, r));
return r;
}
@ -646,7 +648,7 @@ class to_llnf_fn {
}
j++;
}
expr r = mk_app(mk_llnf_cnstr(cidx, k_info.m_num_usizes, k_info.m_scalar_sz), obj_args);
expr r = mk_app(mk_llnf_cnstr(I, cidx, k_info.m_num_usizes, k_info.m_scalar_sz), obj_args);
j = nparams;
unsigned offset = 0;
unsigned uidx = 0;
@ -773,6 +775,7 @@ public:
}
};
/* Push projections inside `cases_on` branches. */
class push_proj_fn {
environment m_env;
name_generator m_ngen;
@ -872,7 +875,7 @@ class push_proj_fn {
flet<local_ctx> save_lctx(m_lctx, m_lctx);
buffer<expr> fvars;
while (is_let(e)) {
expr val = instantiate_rev(let_value(e), fvars.size(), fvars.data());
expr val = visit(instantiate_rev(let_value(e), fvars.size(), fvars.data()));
expr new_fvar = m_lctx.mk_local_decl(m_ngen, let_name(e), let_type(e), val);
fvars.push_back(new_fvar);
e = let_body(e);
@ -900,11 +903,237 @@ public:
};
class insert_reset_reuse_fn {
type_checker::state m_st;
bool m_unboxed;
local_ctx m_lctx;
name m_r{"_r"};
unsigned m_next_idx{1};
struct opt_ctx {
name I;
cnstr_info cinfo;
expr x;
};
struct replace_ctx {
name I;
cnstr_info cinfo;
expr x;
expr reset_x;
replace_ctx(opt_ctx const & octx, expr const & rx):
I(octx.I), cinfo(octx.cinfo), x(octx.x), reset_x(rx) {}
};
environment const & env() { return m_st.env(); }
name_generator & ngen() { return m_st.ngen(); }
expr replace_cnstr(replace_ctx const & rctx, expr const & e) {
buffer<expr> args;
expr const & c = get_app_args(e, args);
name I; unsigned cidx; unsigned nusizes; unsigned ssz;
lean_verify(is_llnf_cnstr(c, I, cidx, nusizes, ssz));
if (I != rctx.I) {
/* Heuristic: we don't want to reuse cells from different types even when they are compatible
because it produces counterintuitive behavior. Here is an example:
```
@list.cases_on a
(@prod.cases_on a_1 (λ fst snd, (punit.star, snd)))
(λ a_hd a_tl,
@prod.cases_on a_1
(λ fst snd,
let _x_1 := nat.add snd a_hd,
_x_2 := (punit.star, _x_1)
in list.mmap'._main._at.accum._spec_1 a_tl _x_2))
```
Without this heuristic, we will try to construct `(punit.star, _x_1)` re-using `a` instead of `a_1`. */
return e;
}
if (args.size() != rctx.cinfo.m_num_objs ||
nusizes != rctx.cinfo.m_num_usizes ||
ssz != rctx.cinfo.m_scalar_sz) {
/* This constructor is not compatible with major premise */
return e;
}
expr r = mk_app(mk_llnf_reuse(cidx, nusizes, ssz, cidx != rctx.cinfo.m_cidx), rctx.reset_x);
return mk_app(r, args);
}
expr replace_app(replace_ctx const & rctx, expr const & e) {
if (is_llnf_cnstr(get_app_fn(e))) {
return replace_cnstr(rctx, e);
} else if (is_cases_on_app(env(), e)) {
lean_assert(!m_replaced);
buffer<expr> args;
expr const & fn = get_app_args(e, args);
bool modified = false;
for (unsigned i = 1; i < args.size(); i++) {
expr new_arg = replace(rctx, args[i]);
if (new_arg != args[i]) {
modified = true;
args[i] = new_arg;
}
}
return modified ? mk_app(fn, args) : e;
} else {
return e;
}
}
expr replace_let(replace_ctx const & rctx, expr const & e) {
expr new_value = replace(rctx, let_value(e));
if (new_value != let_value(e)) {
return update_let(e, let_type(e), new_value, let_body(e));
} else {
expr new_body = replace(rctx, let_body(e));
return update_let(e, let_type(e), new_value, new_body);
}
}
expr replace_lambda(replace_ctx const & rctx, expr const & e) {
expr new_body = replace(rctx, binding_body(e));
return update_binding(e, binding_domain(e), new_body);
}
expr replace(replace_ctx const & rctx, expr const & e) {
switch (e.kind()) {
case expr_kind::App: return replace_app(rctx, e);
case expr_kind::Let: return replace_let(rctx, e);
case expr_kind::Lambda: return replace_lambda(rctx, e);
default: return e;
}
}
name next_reset_name() {
name r(m_r, m_next_idx);
m_next_idx++;
return r;
}
expr replace(opt_ctx const & octx, expr const & e) {
expr reset_x = mk_fvar(ngen().next());
replace_ctx rctx(octx, reset_x);
expr new_e = replace(rctx, e);
if (e == new_e) return e;
expr reset = mk_app(mk_llnf_reset(octx.cinfo.m_num_objs), octx.x);
return ::lean::mk_let(next_reset_name(), mk_enf_object_type(), reset, abstract(new_e, reset_x));
}
expr opt_let(opt_ctx const & octx, expr e) {
lean_assert(is_let(e));
lean_assert(has_fvar(e, octx.x));
flet<local_ctx> save_lctx(m_lctx, m_lctx);
buffer<expr> fvars;
while (is_let(e)) {
expr val = instantiate_rev(let_value(e), fvars.size(), fvars.data());
expr new_fvar = m_lctx.mk_local_decl(ngen(), let_name(e), let_type(e), val);
fvars.push_back(new_fvar);
if (has_fvar(let_value(e), octx.x) && !has_fvar(let_body(e), octx.x)) {
expr new_body = instantiate_rev(let_body(e), fvars.size(), fvars.data());
new_body = replace(octx, new_body);
return m_lctx.mk_lambda(fvars, new_body);
}
e = let_body(e);
}
lean_assert(has_fvar(e, octx.x));
expr new_body = opt(octx, instantiate_rev(e, fvars.size(), fvars.data()));
return m_lctx.mk_lambda(fvars, new_body);
}
expr opt_cases(opt_ctx const & octx, expr const & e) {
buffer<expr> args;
expr const & fn = get_app_args(e, args);
bool modified = false;
for (unsigned i = 1; i < args.size(); i++) {
expr arg = args[i];
expr new_arg = opt(octx, arg);
if (arg != new_arg) {
modified = true;
args[i] = new_arg;
}
}
return modified ? mk_app(fn, args) : e;
}
expr opt(opt_ctx const & octx, expr const & e) {
if (!has_fvar(e, octx.x)) {
return replace(octx, e);
} else if (is_let(e)) {
return opt_let(octx, e);
} else if (is_cases_on_app(env(), e)) {
return opt_cases(octx, e);
} else {
return e;
}
}
expr optimize_cases_on(expr const & e) {
lean_assert(is_cases_on_app(env(), e));
buffer<expr> cases_args;
expr const & cases_fn = get_app_args(e, cases_args);
name const & I_name = const_name(cases_fn).get_prefix();
expr const & major = cases_args[0];
buffer<name> cnames;
get_constructor_names(env(), I_name, cnames);
lean_assert(cases_args.size() == cnames.size() + 1);
for (unsigned i = 1; i < cases_args.size(); i++) {
cnstr_info cinfo = get_cnstr_info(m_st, m_unboxed, cnames[i-1]);
expr minor = optimize(cases_args[i]);
minor = opt(opt_ctx{I_name, cinfo, major}, minor);
cases_args[i] = minor;
}
return mk_app(cases_fn, cases_args);
}
expr optimize_app(expr const & e) {
if (is_cases_on_app(env(), e))
return optimize_cases_on(e);
else
return e;
}
expr optimize_lambda(expr e) {
lean_assert(is_lambda(e));
flet<local_ctx> save_lctx(m_lctx, m_lctx);
buffer<expr> fvars;
while (is_lambda(e)) {
/* Types are ignored in compilation steps. So, we do not invoke visit for d. */
expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(e), binding_domain(e));
fvars.push_back(new_fvar);
e = binding_body(e);
}
expr new_body = optimize(instantiate_rev(e, fvars.size(), fvars.data()));
return m_lctx.mk_lambda(fvars, new_body);
}
expr optimize_let(expr e) {
lean_assert(is_let(e));
flet<local_ctx> save_lctx(m_lctx, m_lctx);
buffer<expr> fvars;
while (is_let(e)) {
expr val = optimize(instantiate_rev(let_value(e), fvars.size(), fvars.data()));
expr new_fvar = m_lctx.mk_local_decl(ngen(), let_name(e), let_type(e), val);
fvars.push_back(new_fvar);
e = let_body(e);
}
expr new_body = optimize(instantiate_rev(e, fvars.size(), fvars.data()));
return m_lctx.mk_lambda(fvars, new_body);
}
expr optimize(expr const & e) {
switch (e.kind()) {
case expr_kind::Lambda: return optimize_lambda(e);
case expr_kind::App: return optimize_app(e);
case expr_kind::Let: return optimize_let(e);
default: return e;
}
}
public:
insert_reset_reuse_fn(environment const & /* env */) {}
insert_reset_reuse_fn(environment const & env, bool unboxed):
m_st(env), m_unboxed(unboxed) {}
expr operator()(expr const & e) {
// TODO(Leo)
return e;
return optimize(e);
}
};
@ -2066,7 +2295,7 @@ pair<environment, comp_decls> to_llnf(environment const & env, comp_decls const
for (comp_decl const & d : ds) {
expr new_v = to_llnf_fn(new_env, unboxed)(d.snd());
new_v = push_proj_fn(new_env)(new_v);
new_v = insert_reset_reuse_fn(new_env)(new_v);
new_v = insert_reset_reuse_fn(new_env, unboxed)(new_v);
rs.push_back(comp_decl(d.fst(), new_v));
if (unboxed) {
if (optional<pair<environment, comp_decl>> p = mk_boxed_version(new_env, d.fst(), get_num_nested_lambdas(d.snd()))) {

View file

@ -17,7 +17,11 @@ optional<pair<environment, comp_decl>> mk_boxed_version(environment env, name co
bool is_llnf_apply(expr const & e);
bool is_llnf_closure(expr const & e);
bool is_llnf_cnstr(expr const & e, unsigned & cidx, unsigned & nusize, unsigned & ssz);
bool is_llnf_cnstr(expr const & e, name & I, unsigned & cidx, unsigned & nusize, unsigned & ssz);
inline bool is_llnf_cnstr(expr const & e, unsigned & cidx, unsigned & nusize, unsigned & ssz) {
name I;
return is_llnf_cnstr(e, I, cidx, nusize, ssz);
}
bool is_llnf_reuse(expr const & e, unsigned & cidx, unsigned & nusize, unsigned & ssz, bool & updt_cidx);
bool is_llnf_reset(expr const & e, unsigned & n);
bool is_llnf_proj(expr const & e, unsigned & idx);