feat(library/compiler/specialize): add spec_info

Store which arguments can be specialized.
This commit is contained in:
Leonardo de Moura 2018-10-15 12:54:34 -07:00
parent 10b99a678c
commit 9ca4c362ae
3 changed files with 209 additions and 12 deletions

View file

@ -56,16 +56,6 @@ comp_decls apply(F && f, comp_decls const & ds) {
return map(ds, [&](comp_decl const & d) { return comp_decl(d.fst(), f(d.snd())); });
}
static pair<environment, comp_decls> specialize(environment env, comp_decls const & ds) {
comp_decls r;
for (comp_decl const & d : ds) {
comp_decls new_ds;
std::tie(env, new_ds) = specialize(env, d);
r = append(r, new_ds);
}
return mk_pair(env, r);
}
static comp_decls lambda_lifting(environment const & env, comp_decls const & ds) {
comp_decls r;
for (comp_decl const & d : ds) {

View file

@ -6,11 +6,202 @@ Author: Leonardo de Moura
*/
#include "runtime/flet.h"
#include "kernel/instantiate.h"
#include "library/module.h"
#include "library/compiler/util.h"
#include "library/trace.h"
namespace lean {
enum class spec_arg_kind { Fixed, FixedInst, Other };
static spec_arg_kind to_spec_arg_kind(object_ref const & r) {
lean_assert(is_scalar(r)); return static_cast<spec_arg_kind>(unbox(r.raw()));
}
typedef objects spec_arg_kinds;
static spec_arg_kinds to_spec_arg_kinds(buffer<spec_arg_kind> const & ks) {
spec_arg_kinds r;
unsigned i = ks.size();
while (i > 0) {
--i;
r = spec_arg_kinds(object_ref(box(static_cast<unsigned>(ks[i]))), r);
}
return r;
}
char const * to_str(spec_arg_kind k) {
switch (k) {
case spec_arg_kind::Fixed: return "F";
case spec_arg_kind::FixedInst: return "I";
case spec_arg_kind::Other: return "X";
}
lean_unreachable();
}
class spec_info : public object_ref {
explicit spec_info(b_obj_arg o, bool b):object_ref(o, b) {}
public:
spec_info(names const & ns, spec_arg_kinds ks):
object_ref(mk_cnstr(0, ns, ks)) {}
spec_info():spec_info(names(), spec_arg_kinds()) {}
spec_info(spec_info const & other):object_ref(other) {}
spec_info(spec_info && other):object_ref(other) {}
spec_info & operator=(spec_info const & other) { object_ref::operator=(other); return *this; }
spec_info & operator=(spec_info && other) { object_ref::operator=(other); return *this; }
names const & get_mutual_decls() const { return static_cast<names const &>(cnstr_get_ref(*this, 0)); }
spec_arg_kinds const & get_arg_kinds() const { return static_cast<spec_arg_kinds const &>(cnstr_get_ref(*this, 1)); }
void serialize(serializer & s) const { s.write_object(raw()); }
static spec_info deserialize(deserializer & d) { return spec_info(d.read_object(), true); }
};
serializer & operator<<(serializer & s, spec_info const & si) { si.serialize(s); return s; }
deserializer & operator>>(deserializer & d, spec_info & si) { si = spec_info::deserialize(d); return d; }
/* Information for executing code specialization.
TODO(Leo): use the to be implemented new module system. */
struct specialize_ext : public environment_extension {
name_map<spec_info> m_spec_info;
// TODO(Leo): cache specialization results
};
struct specialize_ext_reg {
unsigned m_ext_id;
specialize_ext_reg() { m_ext_id = environment::register_extension(std::make_shared<specialize_ext>()); }
};
static specialize_ext_reg * g_ext = nullptr;
static specialize_ext const & get_extension(environment const & env) {
return static_cast<specialize_ext const &>(env.get_extension(g_ext->m_ext_id));
}
static environment update(environment const & env, specialize_ext const & ext) {
return env.update(g_ext->m_ext_id, std::make_shared<specialize_ext>(ext));
}
/* Support for old module manager.
Remark: this code will be deleted in the future */
struct spec_info_modification : public modification {
LEAN_MODIFICATION("speci")
name m_name;
spec_info m_spec_info;
spec_info_modification(name const & n, spec_info const & s) : m_name(n), m_spec_info(s) {}
void perform(environment & env) const override {
specialize_ext ext = get_extension(env);
ext.m_spec_info.insert(m_name, m_spec_info);
}
void serialize(serializer & s) const override {
s << m_name << m_spec_info;
}
static std::shared_ptr<modification const> deserialize(deserializer & d) {
name n; spec_info s;
d >> n >> s;
return std::make_shared<spec_info_modification>(n, s);
}
};
typedef buffer<pair<name, buffer<spec_arg_kind>>> spec_info_buffer;
/* We only specialize arguments that are "fixed" in mutual recursive declarations.
The buffer `info_buffer` stores which arguments are fixed for each declaration in a mutual recursive declaration.
This procedure traverses `e` and updates `info_buffer`.
Remark: we only create free variables for the header of each declaration. Then, we assume an argument of a
recursive call is fixed iff it is a free variable (see `update_spec_info`). */
static void update_info_buffer(environment const & env, expr e, name_set const & S, spec_info_buffer & info_buffer) {
while (true) {
switch (e.kind()) {
case expr_kind::Lambda:
e = binding_body(e);
break;
case expr_kind::Let:
update_info_buffer(env, let_value(e), S, info_buffer);
e = let_body(e);
break;
case expr_kind::App:
if (is_cases_on_app(env, e)) {
buffer<expr> args;
expr const & c_fn = get_app_args(e, args);
unsigned minors_begin; unsigned minors_end;
std::tie(minors_begin, minors_end) = get_cases_on_minors_range(env, const_name(c_fn));
for (unsigned i = minors_begin; i < minors_end; i++) {
update_info_buffer(env, args[i], S, info_buffer);
}
} else {
buffer<expr> args;
expr const & fn = get_app_args(e, args);
if (is_constant(fn) && S.contains(const_name(fn))) {
for (auto & entry : info_buffer) {
if (entry.first == const_name(fn)) {
unsigned sz = entry.second.size();
for (unsigned i = 0; i < sz; i++) {
if (i >= args.size() || !is_fvar(args[i])) {
entry.second[i] = spec_arg_kind::Other;
}
}
break;
}
}
}
}
return;
default:
return;
}
}
}
environment update_spec_info(environment const & env, comp_decls const & ds) {
name_set S;
spec_info_buffer d_infos;
/* Initialzie d_infos and S */
for (comp_decl const & d : ds) {
S.insert(d.fst());
d_infos.push_back(pair<name, buffer<spec_arg_kind>>());
auto & info = d_infos.back();
info.first = d.fst();
expr code = d.snd();
while (is_lambda(code)) {
if (is_inst_implicit(binding_info(code)))
info.second.push_back(spec_arg_kind::FixedInst);
else
info.second.push_back(spec_arg_kind::Fixed);
code = binding_body(code);
}
}
/* Update d_infos */
name x("_x");
for (comp_decl const & d : ds) {
buffer<expr> fvars;
expr code = d.snd();
unsigned i = 1;
/* Create free variables for header variables. */
while (is_lambda(code)) {
fvars.push_back(mk_fvar(name(x, i)));
code = binding_body(code);
}
code = instantiate_rev(code, fvars.size(), fvars.data());
update_info_buffer(env, code, S, d_infos);
}
/* Update extension */
environment new_env = env;
specialize_ext ext = get_extension(env);
names mutual_decls = map2<name>(ds, [&](comp_decl const & d) { return d.fst(); });
for (pair<name, buffer<spec_arg_kind>> const & info : d_infos) {
name const & n = info.first;
spec_info si(mutual_decls, to_spec_arg_kinds(info.second));
lean_trace(name({"compiler", "spec_info"}), tout() << n;
for (spec_arg_kind k : info.second) {
tout() << " " << to_str(k);
}
tout() << "\n";);
new_env = module::add(new_env, std::make_shared<spec_info_modification>(n, si));
ext.m_spec_info.insert(n, si);
}
return update(new_env, ext);
}
class specialize_fn {
type_checker::state m_st;
local_ctx m_lctx;
@ -95,13 +286,29 @@ public:
}
};
pair<environment, comp_decls> specialize(environment const & env, comp_decl const & d) {
pair<environment, comp_decls> specialize_core(environment const & env, comp_decl const & d) {
return specialize_fn(env)(d);
}
pair<environment, comp_decls> specialize(environment env, comp_decls const & ds) {
env = update_spec_info(env, ds);
comp_decls r;
for (comp_decl const & d : ds) {
comp_decls new_ds;
std::tie(env, new_ds) = specialize_core(env, d);
r = append(r, new_ds);
}
return mk_pair(env, r);
}
void initialize_specialize() {
g_ext = new specialize_ext_reg();
spec_info_modification::init();
register_trace_class({"compiler", "spec_info"});
}
void finalize_specialize() {
spec_info_modification::finalize();
delete g_ext;
}
}

View file

@ -8,7 +8,7 @@ Author: Leonardo de Moura
#include "kernel/environment.h"
#include "library/compiler/util.h"
namespace lean {
pair<environment, comp_decls> specialize(environment const & env, comp_decl const & d);
pair<environment, comp_decls> specialize(environment env, comp_decls const & ds);
void initialize_specialize();
void finalize_specialize();
}