feat(library/compiler/specialize): code specialization

TODO:
- Cache results at `specialize_ext`
- Cleanup

It is not feasible to run the code specializer without a cache: doing so causes a code explosion.
This commit is contained in:
Leonardo de Moura 2018-10-16 15:48:18 -07:00
parent af682a0981
commit 611f6ae780
4 changed files with 388 additions and 95 deletions

View file

@ -49,14 +49,16 @@ protected:
template<bool is_lambda> expr mk_binding(unsigned num, expr const * fvars, expr const & b) const;
local_decl mk_local_decl(name const & n, name const & un, expr const & type, binder_info bi);
local_decl mk_local_decl(name const & n, name const & un, expr const & type, expr const & value);
public:
/* Default constructor: starts the `m_next_idx` counter at 0. */
local_ctx():m_next_idx(0) {}
/* Return true iff `m_idx2local_decl` contains no entries. */
bool empty() const { return m_idx2local_decl.empty(); }
/* Low level `mk_local_decl` */
local_decl mk_local_decl(name const & n, name const & un, expr const & type, binder_info bi);
/* Low level `mk_local_decl` */
local_decl mk_local_decl(name const & n, name const & un, expr const & type, expr const & value);
/* Convenience overload: takes the next name from `g`, creates the declaration through the
   low-level `mk_local_decl` (binder-info variant), and returns the result of `mk_ref()` on it. */
expr mk_local_decl(name_generator & g, name const & un, expr const & type, binder_info bi = mk_binder_info()) {
    local_decl new_decl = mk_local_decl(g.next(), un, type, bi);
    return new_decl.mk_ref();
}

View file

@ -96,7 +96,8 @@ environment compile(environment const & env, options const & opts, names const &
}
comp_decls ds = to_comp_decls(env, cs);
auto simp = [&](environment const & env, expr const & e) { return csimp(env, e, csimp_cfg(opts)); };
csimp_cfg cfg(opts);
auto simp = [&](environment const & env, expr const & e) { return csimp(env, e, cfg); };
trace_compiler(name({"compiler", "input"}), ds);
ds = apply(eta_expand, env, ds);
trace_compiler(name({"compiler", "eta_expand"}), ds);
@ -113,7 +114,7 @@ environment compile(environment const & env, options const & opts, names const &
ds = apply(max_sharing, ds);
trace_compiler(name({"compiler", "stage1"}), ds);
environment new_env = cache_stage1(env, ds);
std::tie(new_env, ds) = specialize(new_env, ds);
std::tie(new_env, ds) = specialize(new_env, ds, cfg);
trace_compiler(name({"compiler", "specialize"}), ds);
ds = apply(elim_dead_let, ds);
trace_compiler(name({"compiler", "elim_dead_let"}), ds);

View file

@ -11,6 +11,7 @@ Author: Leonardo de Moura
#include "library/module.h"
#include "library/attribute_manager.h"
#include "library/compiler/util.h"
#include "library/compiler/csimp.h"
#include "library/trace.h"
@ -287,9 +288,14 @@ environment update_spec_info(environment const & env, comp_decls const & ds) {
class specialize_fn {
type_checker::state m_st;
csimp_cfg m_cfg;
specialize_ext m_ext;
local_ctx m_lctx;
buffer<comp_decl> m_new_decls;
name m_base_name;
name m_at;
name m_spec;
unsigned m_next_idx{1};
environment const & env() { return m_st.env(); }
@ -348,7 +354,135 @@ class specialize_fn {
return e;
}
void collect_deps(expr e, name_set & collected, buffer<expr> & new_params, buffer<expr> & let_vars) {
struct spec_ctx {
typedef rb_expr_map<name> cache;
names m_mutual;
buffer<expr> m_params;
buffer<expr> m_let_vars;
cache m_cache;
buffer<comp_decl> m_pre_decls;
bool in_mutual_decl(name const & n) const {
return std::find(m_mutual.begin(), m_mutual.end(), n) != m_mutual.end();
}
};
/* Copy the argument kinds recorded for `fn` in `m_ext.m_spec_info` into `kinds`.
   An entry for `fn` is expected to exist (checked only by `lean_assert`). */
void get_arg_kinds(name const & fn, buffer<spec_arg_kind> & kinds) {
    spec_info const * fn_info = m_ext.m_spec_info.find(fn);
    lean_assert(fn_info);
    to_buffer(fn_info->get_arg_kinds(), kinds);
}
/* Fill `mask` with one flag per argument position, scanning `kinds` from the last position
   down to the first:
   - `FixedInst` positions are always flagged;
   - `FixedHO`, `FixedNeutral` and `Fixed` positions are flagged when `has_attr` is set or a
     `FixedInst` position has already been seen during the (backwards) scan;
   - `Other` positions are never flagged.
   When the first flag is set, `mask` is shrunk so that it ends at that position. If nothing is
   flagged, `mask` keeps `kinds.size()` entries. */
static void to_bool_mask(buffer<spec_arg_kind> const & kinds, bool has_attr, buffer<bool> & mask) {
    unsigned const num_kinds = kinds.size();
    mask.resize(num_kinds, false);
    bool seen_inst   = false;
    bool nothing_set = true;
    /* Flag `idx`. The first flagged position is the highest one, so trailing entries are dropped. */
    auto select = [&](unsigned idx) {
        mask[idx] = true;
        if (nothing_set) {
            mask.shrink(idx + 1);
            nothing_set = false;
        }
    };
    for (unsigned idx = num_kinds; idx-- > 0; ) {
        switch (kinds[idx]) {
        case spec_arg_kind::Other:
            break;
        case spec_arg_kind::FixedInst:
            select(idx);
            seen_inst = true;
            break;
        case spec_arg_kind::FixedHO:
        case spec_arg_kind::FixedNeutral:
        case spec_arg_kind::Fixed:
            if (has_attr || seen_inst)
                select(idx);
            break;
        }
    }
}
/* Compute the specialization mask of `fn`: fetch its argument kinds and hand them to
   `to_bool_mask` together with whether `fn` carries the specialize attribute. */
void get_bool_mask(name const & fn, buffer<bool> & mask) {
    buffer<spec_arg_kind> arg_kinds;
    get_arg_kinds(fn, arg_kinds);
    bool const has_attr = has_specialize_attribute(env(), fn);
    to_bool_mask(arg_kinds, has_attr, mask);
}
/* Build the name of a new specialization of `fn` as
   `fn + m_at + m_base_name + m_spec.append_after(m_next_idx)`, then advance `m_next_idx`. */
name mk_spec_name(name const & fn) {
    name suffix = m_spec.append_after(m_next_idx);
    ++m_next_idx;
    return fn + m_at + m_base_name + suffix;
}
/* Encode `fn` and `mask` as a single expression usable as a cache key: `fn` is applied to
   one argument per mask entry, using the entry's expression when present and a
   default-constructed `expr()` otherwise. */
static expr mk_cache_key(expr const & fn, buffer<optional<expr>> const & mask) {
    expr key = fn;
    for (optional<expr> const & entry : mask)
        key = mk_app(key, entry ? *entry : expr());
    return key;
}
/* Decide whether the application `fn args` is worth specializing.
   `fn` must be a constant with an entry in the specialization info table.
   Returns true iff
   - `fn` has the specialize attribute or at least one fixed instance argument,
   - its specialization arity is non-zero and fully covered by `args`, and
   - at least one inspected argument has a shape that makes specialization useful
     (see the per-kind remarks below). */
bool is_specialize_candidate(expr const & fn, buffer<expr> const & args) {
    lean_assert(is_constant(fn));
    buffer<spec_arg_kind> kinds;
    get_arg_kinds(const_name(fn), kinds);
    bool const has_attr = has_specialize_attribute(env(), const_name(fn));
    if (!has_attr && !has_fixed_inst_arg(kinds))
        return false; /* Nothing to specialize */
    unsigned const spec_arity = get_specialization_arity(kinds);
    if (spec_arity == 0)
        return false; /* Nothing to specialize */
    /* We do not perform partial specialization.
       We only specialize if all fixed arguments have been provided. */
    if (spec_arity > args.size())
        return false;
    type_checker tc(m_st, m_lctx);
    /* Only positions that have both an argument and a recorded kind are inspected. */
    for (unsigned i = 0; i < args.size() && i < kinds.size(); ++i) {
        expr const & arg = args[i];
        switch (kinds[i]) {
        case spec_arg_kind::FixedNeutral:
        case spec_arg_kind::Other:
            break;
        case spec_arg_kind::FixedInst: {
            /* We specialize this kind of argument if it reduces to a constructor application.
               Type class instance arguments are usually free variables bound to lambda
               declarations, or quickly reduce to constructor applications, so this `whnf`
               is probably harmless. */
            expr reduced = tc.whnf(arg);
            if (is_constructor_app(env(), reduced))
                return true;
            break;
        }
        case spec_arg_kind::FixedHO: {
            /* We specialize a higher-order argument if it is a lambda or an application of
               a constant.
               Remark: it is not feasible to invoke `whnf` since it may consume a lot of time. */
            expr found = find(arg);
            if (is_lambda(found) || is_constant(get_app_fn(found)))
                return true;
            break;
        }
        case spec_arg_kind::Fixed: {
            /* We specialize this kind of argument if it is a constructor application or a literal.
               Remark: it is not feasible to invoke `whnf` since it may consume a lot of time. */
            expr found = find(arg);
            if (is_constructor_app(env(), found) || is_lit(found))
                return true;
            break;
        }
        }
    }
    return false;
}
void collect_deps(expr e, name_set & collected, spec_ctx & ctx) {
buffer<expr> todo;
while (true) {
for_each(e, [&](expr const & x, unsigned) {
@ -356,10 +490,10 @@ class specialize_fn {
if (is_fvar(x) && !collected.contains(fvar_name(x))) {
collected.insert(fvar_name(x));
if (optional<expr> v = m_lctx.get_local_decl(x).get_value()) {
let_vars.push_back(x);
ctx.m_let_vars.push_back(x);
todo.push_back(*v);
} else {
new_params.push_back(x);
ctx.m_params.push_back(x);
}
}
return true;
@ -371,15 +505,22 @@ class specialize_fn {
}
}
optional<expr> specialize(expr const & fn, buffer<expr> const & args, names const & mutual, buffer<spec_arg_kind> const & kinds, bool has_attr) {
/* Sort `fvars` in place by increasing index of their declarations in `m_lctx`. */
void sort_fvars(buffer<expr> & fvars) {
    auto decl_idx = [&](expr const & fvar) {
        return m_lctx.get_local_decl(fvar).get_idx();
    };
    auto by_decl_idx = [&](expr const & lhs, expr const & rhs) {
        return decl_idx(lhs) < decl_idx(rhs);
    };
    std::sort(fvars.begin(), fvars.end(), by_decl_idx);
}
/* Initialize `spec_ctx` fields: `m_params`, `m_let_vars`. */
void specialize_init_deps(expr const & fn, buffer<expr> const & args, spec_ctx & ctx) {
lean_assert(is_constant(fn));
buffer<spec_arg_kind> kinds;
get_arg_kinds(const_name(fn), kinds);
bool has_attr = has_specialize_attribute(env(), const_name(fn));
name_set collected;
buffer<expr> new_params;
buffer<expr> let_vars;
unsigned sz = get_specialization_arity(kinds);
lean_assert(sz <= args.size());
unsigned sz = kinds.size();
unsigned i = sz;
buffer<bool> mask;
mask.resize(args.size(), false);
bool found_inst = false;
while (i > 0) {
--i;
@ -387,33 +528,234 @@ class specialize_fn {
case spec_arg_kind::Other:
break;
case spec_arg_kind::FixedInst:
mask[i] = true;
collect_deps(args[i], collected, new_params, let_vars);
collect_deps(args[i], collected, ctx);
found_inst = true;
break;
case spec_arg_kind::FixedHO:
case spec_arg_kind::FixedNeutral:
case spec_arg_kind::Fixed:
if (has_attr || found_inst) {
mask[i] = true;
collect_deps(args[i], collected, new_params, let_vars);
collect_deps(args[i], collected, ctx);
}
break;
}
}
std::sort(new_params.begin(), new_params.end(),
[&](expr const & x, expr const & y) { return m_lctx.get_local_decl(x).get_idx() < m_lctx.get_local_decl(y).get_idx(); });
std::sort(let_vars.begin(), let_vars.end(),
[&](expr const & x, expr const & y) { return m_lctx.get_local_decl(x).get_idx() < m_lctx.get_local_decl(y).get_idx(); });
sort_fvars(ctx.m_params);
sort_fvars(ctx.m_let_vars);
lean_trace(name({"compiler", "spec_candidate"}),
tout() << "candidate: " << mk_app(fn, args) << "\nclosure:";
for (expr const & p : new_params) tout() << " " << p;
for (expr const & x : let_vars) tout() << " " << x;
tout() << "\nmask:";
for (bool m : mask) tout() << " " << m;
for (expr const & p : ctx.m_params) tout() << " " << p;
for (expr const & x : ctx.m_let_vars) tout() << " " << x;
tout() << "\n";);
// TODO(Leo):
return none_expr();
}
/* Return true iff some present entry of `mask` is equal to `e`. */
static bool contains(buffer<optional<expr>> const & mask, expr const & e) {
    for (unsigned i = 0; i < mask.size(); ++i) {
        optional<expr> const & entry = mask[i];
        if (!entry)
            continue;
        if (*entry == e)
            return true;
    }
    return false;
}
/* Rewrite the code `e` of a function that is being specialized so that calls to functions of
   the mutual block (`ctx.in_mutual_decl`) which pass along expressions occurring in `mask`
   are redirected to a matching specialization (created on demand by `spec_preprocess`).

   Traversal:
   - `cases_on` applications: only the minor premises are visited;
   - other applications: rewritten when the head is a constant of the mutual block and at
     least one argument selected by the callee's bool mask occurs in `mask`;
   - lambdas / lets: the bodies (and let values) are visited and the binders rebuilt;
   - anything else is returned unchanged.

   Returns `none_expr()` as soon as a required specialization cannot be produced. */
optional<expr> adjust_rec_apps(expr e, buffer<optional<expr>> const & mask, spec_ctx & ctx) {
    switch (e.kind()) {
    case expr_kind::App:
        if (is_cases_on_app(env(), e)) {
            buffer<expr> args;
            expr const & c = get_app_args(e, args);
            /* visit minor premises */
            unsigned minor_idx; unsigned minors_end;
            std::tie(minor_idx, minors_end) = get_cases_on_minors_range(env(), const_name(c));
            for (; minor_idx < minors_end; minor_idx++) {
                optional<expr> new_arg = adjust_rec_apps(args[minor_idx], mask, ctx);
                if (!new_arg) return none_expr();
                args[minor_idx] = *new_arg;
            }
            return some_expr(mk_app(c, args));
        } else {
            expr const & fn = get_app_fn(e);
            if (!is_constant(fn) || !ctx.in_mutual_decl(const_name(fn)))
                return some_expr(e);
            buffer<expr> args;
            get_app_args(e, args);
            buffer<bool> bmask;
            get_bool_mask(const_name(fn), bmask);
            if (bmask.size() > args.size()) {
                /* Under-applied call: not every position covered by the callee's mask has an
                   argument here. Indexing `args` with the mask would run past its end, so the
                   call is left as it is (it keeps targeting the unspecialized function). */
                return some_expr(e);
            }
            /* `new_mask[i]` holds `args[i]` when position `i` is specialized in the callee and
               the argument is one of the expressions being specialized on. */
            buffer<optional<expr>> new_mask;
            bool found = false;
            for (unsigned i = 0; i < bmask.size(); i++) {
                if (bmask[i] && contains(mask, args[i])) {
                    found = true;
                    new_mask.push_back(some_expr(args[i]));
                } else {
                    new_mask.push_back(none_expr());
                }
            }
            if (!found)
                return some_expr(e);
            optional<name> new_fn_name = spec_preprocess(fn, new_mask, ctx);
            if (!new_fn_name) return none_expr();
            /* New call: specialization applied to the shared parameters, followed by the
               arguments that were not absorbed by the specialization. */
            expr r = mk_constant(*new_fn_name);
            r = mk_app(r, ctx.m_params);
            for (unsigned i = 0; i < bmask.size(); i++) {
                if (!bmask[i] || !contains(mask, args[i]))
                    r = mk_app(r, args[i]);
            }
            for (unsigned i = bmask.size(); i < args.size(); i++) {
                r = mk_app(r, args[i]);
            }
            return some_expr(r);
        }
    case expr_kind::Lambda: {
        /* Peel all leading lambdas, process the body, then rebuild the binders inside-out. */
        buffer<expr> entries;
        while (is_lambda(e)) {
            entries.push_back(e);
            e = binding_body(e);
        }
        optional<expr> new_e = adjust_rec_apps(e, mask, ctx);
        if (!new_e) return none_expr();
        expr r = *new_e;
        unsigned i = entries.size();
        while (i > 0) {
            --i;
            expr l = entries[i];
            r = mk_lambda(binding_name(l), binding_domain(l), r);
        }
        return some_expr(r);
    }
    case expr_kind::Let: {
        /* Process every let value in the chain, then the final body, then rebuild the lets
           inside-out with the new values. */
        buffer<pair<expr, expr>> entries;
        while (is_let(e)) {
            optional<expr> v = adjust_rec_apps(let_value(e), mask, ctx);
            if (!v) return none_expr();
            expr new_val = *v;
            entries.emplace_back(e, new_val);
            e = let_body(e);
        }
        optional<expr> new_e = adjust_rec_apps(e, mask, ctx);
        if (!new_e) return none_expr();
        expr r = *new_e;
        unsigned i = entries.size();
        while (i > 0) {
            --i;
            expr l = entries[i].first;
            expr v = entries[i].second;
            r = mk_let(let_name(l), let_type(l), v, r);
        }
        return some_expr(r);
    }
    default:
        return some_expr(e);
    }
}
/* Produce (or reuse) the pre-declaration of `fn` specialized according to `mask`, and return
   its name. `fn` must be a constant belonging to the mutual block of `ctx`.
   `mask` has one entry per leading parameter of `fn`'s code: a present entry is the free
   variable standing for the specialized argument, an absent entry means the parameter is kept.
   The resulting code is appended to `ctx.m_pre_decls`.
   Returns an empty optional if the stage-1 code of `fn` is unavailable or if adjusting the
   recursive applications fails. */
optional<name> spec_preprocess(expr const & fn, buffer<optional<expr>> const & mask, spec_ctx & ctx) {
lean_assert(is_constant(fn));
lean_assert(ctx.in_mutual_decl(const_name(fn)));
/* Reuse an existing specialization for the same function/mask combination. */
expr key = mk_cache_key(fn, mask);
if (name const * r = ctx.m_cache.find(key)) {
return optional<name>(*r);
}
/* The code to specialize is looked up in the environment under the name given by `mk_cstage1_name`. */
optional<constant_info> info = env().find(mk_cstage1_name(const_name(fn)));
if (!info || !info->is_definition()) return optional<name>(); // failed
name new_name = mk_spec_name(const_name(fn));
/* The name is registered before the body is processed, so that a request for the same key
   issued from `adjust_rec_apps` below is answered by the cache lookup above. */
ctx.m_cache.insert(key, new_name);
expr new_code = instantiate_value_lparams(*info, const_levels(fn));
/* NOTE(review): `flet` presumably restores `m_lctx` when this function returns - confirm. */
flet<local_ctx> save_lctx(m_lctx, m_lctx);
/* `fvars` collects one free variable per mask entry (used to instantiate the body);
   `new_fvars` collects only the freshly created ones (the parameters that remain). */
buffer<expr> fvars;
buffer<expr> new_fvars;
for (optional<expr> const & b : mask) {
lean_assert(is_lambda(new_code));
if (b) {
lean_assert(is_fvar(*b));
fvars.push_back(*b);
} else {
expr type = instantiate_rev(binding_domain(new_code), fvars.size(), fvars.data());
expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(new_code), type, binding_info(new_code));
new_fvars.push_back(new_fvar);
fvars.push_back(new_fvar);
}
new_code = binding_body(new_code);
}
new_code = instantiate_rev(new_code, fvars.size(), fvars.data());
optional<expr> c = adjust_rec_apps(new_code, mask, ctx);
if (!c) return optional<name>();
new_code = *c;
/* Abstract only the fresh parameters; the free variables supplied through `mask` stay free here. */
new_code = m_lctx.mk_lambda(new_fvars, new_code);
ctx.m_pre_decls.push_back(comp_decl(new_name, new_code));
lean_trace(name({"compiler", "specialize"}), tout() << "new specialization " << new_name << " :=\n" << new_code << "\n";);
return optional<name>(new_name);
}
/* Turn the pre-declaration `pre_decl` (produced by `spec_preprocess`) into a closed
   declaration and append it to `m_new_decls`.
   - `fvars` are the placeholder free variables used for the specialized arguments and
     `fvar_vals[i]` is the value bound to `fvars[i]`; both buffers must have the same size.
   - The code is abstracted, innermost first, over: the placeholder declarations,
     `ctx.m_let_vars`, the declaration's own parameters, and finally `ctx.m_params`.
   The closed code is then simplified with `csimp` and processed by `visit`. */
void mk_new_decl(comp_decl const & pre_decl, buffer<expr> const & fvars, buffer<expr> const & fvar_vals, spec_ctx & ctx) {
    lean_assert(fvars.size() == fvar_vals.size());
    name n = pre_decl.fst();
    expr code = pre_decl.snd();
    /* NOTE(review): `flet` presumably restores `m_lctx` on scope exit - confirm. */
    flet<local_ctx> save_lctx(m_lctx, m_lctx);
    /* Re-open the leading lambdas of the pre-declaration with fresh local declarations. */
    buffer<expr> new_fvars;
    while (is_lambda(code)) {
        expr type = instantiate_rev(binding_domain(code), new_fvars.size(), new_fvars.data());
        expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(code), type, binding_info(code));
        new_fvars.push_back(new_fvar);
        code = binding_body(code);
    }
    code = instantiate_rev(code, new_fvars.size(), new_fvars.data());
    /* Add fvars decls: each placeholder is declared under its own name, with the inferred
       type of its value and the value itself. */
    type_checker tc(m_st, m_lctx);
    buffer<expr> new_let_decls;
    name y("_y");
    for (unsigned i = 0; i < fvars.size(); i++) {
        expr type = tc.infer(fvar_vals[i]);
        expr new_fvar = m_lctx.mk_local_decl(fvar_name(fvars[i]), y.append_after(i+1), type, fvar_vals[i]).mk_ref();
        new_let_decls.push_back(new_fvar);
    }
    code = m_lctx.mk_lambda(new_let_decls, code);
    code = m_lctx.mk_lambda(ctx.m_let_vars, code);
    code = m_lctx.mk_lambda(new_fvars, code);
    code = m_lctx.mk_lambda(ctx.m_params, code);
    lean_assert(!has_fvar(code));
    code = csimp(env(), code, m_cfg);
    code = visit(code);
    /* Diagnostic dump of the final code, emitted only when the trace class is enabled. */
    lean_trace(name({"compiler", "specialize"}), tout() << n << " :=\n" << code << "\n";);
    m_new_decls.push_back(comp_decl(n, code));
}
/* Try to specialize the application `fn args`.
   On success, the new declarations are generated (via `spec_preprocess` / `mk_new_decl`) and
   the replacement call is returned: the specialized function applied to `ctx.m_params`
   followed by every argument that was not specialized. Returns `none_expr()` when the
   application is not a candidate or when preprocessing fails. */
optional<expr> specialize(expr const & fn, buffer<expr> const & args, spec_ctx & ctx) {
    if (!is_specialize_candidate(fn, args))
        return none_expr();
    specialize_init_deps(fn, args, ctx);

    buffer<bool> bmask;
    get_bool_mask(const_name(fn), bmask);
    unsigned const mask_sz = bmask.size();

    /* For every specialized position, introduce a placeholder free variable and remember the
       actual argument as its value. */
    buffer<optional<expr>> mask;
    buffer<expr> fvars;
    buffer<expr> fvar_vals;
    for (unsigned i = 0; i < mask_sz; ++i) {
        if (!bmask[i]) {
            mask.push_back(none_expr());
            continue;
        }
        expr placeholder = mk_fvar(ngen().next());
        fvars.push_back(placeholder);
        fvar_vals.push_back(args[i]);
        mask.push_back(some_expr(placeholder));
    }

    optional<name> spec_name = spec_preprocess(fn, mask, ctx);
    if (!spec_name)
        return none_expr();
    for (comp_decl const & pre : ctx.m_pre_decls)
        mk_new_decl(pre, fvars, fvar_vals, ctx);

    expr result = mk_constant(*spec_name);
    result = mk_app(result, ctx.m_params);
    /* Arguments covered by the mask but not specialized ... */
    for (unsigned i = 0; i < mask_sz; ++i) {
        if (!bmask[i])
            result = mk_app(result, args[i]);
    }
    /* ... and the arguments beyond the mask. */
    for (unsigned i = mask_sz; i < args.size(); ++i)
        result = mk_app(result, args[i]);
    return some_expr(result);
}
expr visit_app(expr const & e) {
@ -427,66 +769,11 @@ class specialize_fn {
|| (is_instance(env(), const_name(fn)) && !has_specialize_attribute(env(), const_name(fn)))) {
return e;
}
specialize_ext ext = get_extension(env());
spec_info const * info = ext.m_spec_info.find(const_name(fn));
spec_info const * info = m_ext.m_spec_info.find(const_name(fn));
if (!info) return e;
bool has_attr = has_specialize_attribute(env(), const_name(fn));
buffer<spec_arg_kind> kinds;
to_buffer(info->get_arg_kinds(), kinds);
if (!has_attr && !has_fixed_inst_arg(kinds))
return e; /* Nothing to specialize */
unsigned spec_arity = get_specialization_arity(kinds);
if (spec_arity == 0)
return e; /* Nothing to specialize */
if (spec_arity > args.size()) {
/* We do not perform partial specialization.
We only specialize if all fixed arguments have been provided. */
return e;
}
type_checker tc(m_st, m_lctx);
bool is_candidate = false;
for (unsigned i = 0; i < args.size(); i++) {
if (i >= kinds.size())
break;
spec_arg_kind k = kinds[i];
expr w;
switch (k) {
case spec_arg_kind::FixedNeutral:
break;
case spec_arg_kind::FixedInst:
/* We specialize this kind of argument if it reduces to a constructor application.
Type class instances arguments are usually free variables bound to lambda declarations,
or quickly reduce to constructor applications. So, the following `whnf` is probably
harmless. */
w = tc.whnf(args[i]);
if (is_constructor_app(env(), w))
is_candidate = true;
break;
case spec_arg_kind::FixedHO:
/* We specialize higher-order arguments if they are lambda applications or
a constant application.
Remark: it is not feasible to invoke whnf since it may consume a lot of time. */
w = find(args[i]);
if (is_lambda(w) || is_constant(get_app_fn(w)))
is_candidate = true;
break;
case spec_arg_kind::Fixed:
/* We specialize this kind of argument if they are constructor applications or literals.
Remark: it is not feasible to invoke whnf since it may consume a lot of time. */
w = find(args[i]);
if (is_constructor_app(env(), w) || is_lit(w))
is_candidate = true;
break;
case spec_arg_kind::Other:
break;
}
if (is_candidate)
break;
}
if (!is_candidate)
return e;
if (optional<expr> r = specialize(fn, args, info->get_mutual_decls(), kinds, has_attr))
spec_ctx ctx;
ctx.m_mutual = info->get_mutual_decls();
if (optional<expr> r = specialize(fn, args, ctx))
return *r;
else
return e;
@ -503,8 +790,8 @@ class specialize_fn {
}
public:
specialize_fn(environment const & env):
m_st(env) {}
/* Create a specializer for `env`: keeps a copy of the simplifier configuration `cfg`, reads
   the extension via `get_extension(env)`, and fixes the name components "_at" and "_spec"
   used by `mk_spec_name`. */
specialize_fn(environment const & env, csimp_cfg const & cfg):
    m_st(env),
    m_cfg(cfg),
    m_ext(get_extension(env)),
    m_at("_at"),
    m_spec("_spec") {
}
pair<environment, comp_decls> operator()(comp_decl const & d) {
m_base_name = d.fst();
@ -514,16 +801,18 @@ public:
}
};
pair<environment, comp_decls> specialize_core(environment const & env, comp_decl const & d) {
return specialize_fn(env)(d);
/* Specialization entry point for a single declaration `d`.
   Currently a no-op: it returns `env` unchanged paired with a `comp_decls` built from `d`
   alone; `cfg` is not used. The call to `specialize_fn` is commented out until the main
   cache is implemented (see the TODO below). */
pair<environment, comp_decls> specialize_core(environment const & env, comp_decl const & d, csimp_cfg const & cfg) {
// TODO(Leo): we still need to implement main cache.
// return specialize_fn(env, cfg)(d);
return mk_pair(env, comp_decls(d));
}
pair<environment, comp_decls> specialize(environment env, comp_decls const & ds) {
pair<environment, comp_decls> specialize(environment env, comp_decls const & ds, csimp_cfg const & cfg) {
env = update_spec_info(env, ds);
comp_decls r;
for (comp_decl const & d : ds) {
comp_decls new_ds;
std::tie(env, new_ds) = specialize_core(env, d);
std::tie(env, new_ds) = specialize_core(env, d, cfg);
r = append(r, new_ds);
}
return mk_pair(env, r);

View file

@ -7,8 +7,9 @@ Author: Leonardo de Moura
#pragma once
#include "kernel/environment.h"
#include "library/compiler/util.h"
#include "library/compiler/csimp.h"
namespace lean {
pair<environment, comp_decls> specialize(environment env, comp_decls const & ds);
pair<environment, comp_decls> specialize(environment env, comp_decls const & ds, csimp_cfg const & cfg);
void initialize_specialize();
void finalize_specialize();
}