feat(library/compiler/specialize): add spec_info
Store which arguments can be specialized.
This commit is contained in:
parent
10b99a678c
commit
9ca4c362ae
3 changed files with 209 additions and 12 deletions
|
|
@ -56,16 +56,6 @@ comp_decls apply(F && f, comp_decls const & ds) {
|
|||
return map(ds, [&](comp_decl const & d) { return comp_decl(d.fst(), f(d.snd())); });
|
||||
}
|
||||
|
||||
static pair<environment, comp_decls> specialize(environment env, comp_decls const & ds) {
|
||||
comp_decls r;
|
||||
for (comp_decl const & d : ds) {
|
||||
comp_decls new_ds;
|
||||
std::tie(env, new_ds) = specialize(env, d);
|
||||
r = append(r, new_ds);
|
||||
}
|
||||
return mk_pair(env, r);
|
||||
}
|
||||
|
||||
static comp_decls lambda_lifting(environment const & env, comp_decls const & ds) {
|
||||
comp_decls r;
|
||||
for (comp_decl const & d : ds) {
|
||||
|
|
|
|||
|
|
@ -6,11 +6,202 @@ Author: Leonardo de Moura
|
|||
*/
|
||||
#include "runtime/flet.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "library/module.h"
|
||||
#include "library/compiler/util.h"
|
||||
|
||||
#include "library/trace.h"
|
||||
|
||||
namespace lean {
|
||||
enum class spec_arg_kind { Fixed, FixedInst, Other };
|
||||
static spec_arg_kind to_spec_arg_kind(object_ref const & r) {
|
||||
lean_assert(is_scalar(r)); return static_cast<spec_arg_kind>(unbox(r.raw()));
|
||||
}
|
||||
typedef objects spec_arg_kinds;
|
||||
static spec_arg_kinds to_spec_arg_kinds(buffer<spec_arg_kind> const & ks) {
|
||||
spec_arg_kinds r;
|
||||
unsigned i = ks.size();
|
||||
while (i > 0) {
|
||||
--i;
|
||||
r = spec_arg_kinds(object_ref(box(static_cast<unsigned>(ks[i]))), r);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
char const * to_str(spec_arg_kind k) {
|
||||
switch (k) {
|
||||
case spec_arg_kind::Fixed: return "F";
|
||||
case spec_arg_kind::FixedInst: return "I";
|
||||
case spec_arg_kind::Other: return "X";
|
||||
}
|
||||
lean_unreachable();
|
||||
}
|
||||
|
||||
class spec_info : public object_ref {
|
||||
explicit spec_info(b_obj_arg o, bool b):object_ref(o, b) {}
|
||||
public:
|
||||
spec_info(names const & ns, spec_arg_kinds ks):
|
||||
object_ref(mk_cnstr(0, ns, ks)) {}
|
||||
spec_info():spec_info(names(), spec_arg_kinds()) {}
|
||||
spec_info(spec_info const & other):object_ref(other) {}
|
||||
spec_info(spec_info && other):object_ref(other) {}
|
||||
spec_info & operator=(spec_info const & other) { object_ref::operator=(other); return *this; }
|
||||
spec_info & operator=(spec_info && other) { object_ref::operator=(other); return *this; }
|
||||
names const & get_mutual_decls() const { return static_cast<names const &>(cnstr_get_ref(*this, 0)); }
|
||||
spec_arg_kinds const & get_arg_kinds() const { return static_cast<spec_arg_kinds const &>(cnstr_get_ref(*this, 1)); }
|
||||
void serialize(serializer & s) const { s.write_object(raw()); }
|
||||
static spec_info deserialize(deserializer & d) { return spec_info(d.read_object(), true); }
|
||||
};
|
||||
|
||||
serializer & operator<<(serializer & s, spec_info const & si) { si.serialize(s); return s; }
|
||||
deserializer & operator>>(deserializer & d, spec_info & si) { si = spec_info::deserialize(d); return d; }
|
||||
|
||||
/* Information for executing code specialization.
|
||||
TODO(Leo): use the to be implemented new module system. */
|
||||
struct specialize_ext : public environment_extension {
|
||||
name_map<spec_info> m_spec_info;
|
||||
// TODO(Leo): cache specialization results
|
||||
};
|
||||
|
||||
struct specialize_ext_reg {
|
||||
unsigned m_ext_id;
|
||||
specialize_ext_reg() { m_ext_id = environment::register_extension(std::make_shared<specialize_ext>()); }
|
||||
};
|
||||
|
||||
static specialize_ext_reg * g_ext = nullptr;
|
||||
static specialize_ext const & get_extension(environment const & env) {
|
||||
return static_cast<specialize_ext const &>(env.get_extension(g_ext->m_ext_id));
|
||||
}
|
||||
static environment update(environment const & env, specialize_ext const & ext) {
|
||||
return env.update(g_ext->m_ext_id, std::make_shared<specialize_ext>(ext));
|
||||
}
|
||||
|
||||
/* Support for old module manager.
|
||||
Remark: this code will be deleted in the future */
|
||||
struct spec_info_modification : public modification {
|
||||
LEAN_MODIFICATION("speci")
|
||||
|
||||
name m_name;
|
||||
spec_info m_spec_info;
|
||||
|
||||
spec_info_modification(name const & n, spec_info const & s) : m_name(n), m_spec_info(s) {}
|
||||
|
||||
void perform(environment & env) const override {
|
||||
specialize_ext ext = get_extension(env);
|
||||
ext.m_spec_info.insert(m_name, m_spec_info);
|
||||
}
|
||||
|
||||
void serialize(serializer & s) const override {
|
||||
s << m_name << m_spec_info;
|
||||
}
|
||||
|
||||
static std::shared_ptr<modification const> deserialize(deserializer & d) {
|
||||
name n; spec_info s;
|
||||
d >> n >> s;
|
||||
return std::make_shared<spec_info_modification>(n, s);
|
||||
}
|
||||
};
|
||||
|
||||
typedef buffer<pair<name, buffer<spec_arg_kind>>> spec_info_buffer;
|
||||
|
||||
/* We only specialize arguments that are "fixed" in mutual recursive declarations.
|
||||
The buffer `info_buffer` stores which arguments are fixed for each declaration in a mutual recursive declaration.
|
||||
This procedure traverses `e` and updates `info_buffer`.
|
||||
|
||||
Remark: we only create free variables for the header of each declaration. Then, we assume an argument of a
|
||||
recursive call is fixed iff it is a free variable (see `update_spec_info`). */
|
||||
static void update_info_buffer(environment const & env, expr e, name_set const & S, spec_info_buffer & info_buffer) {
|
||||
while (true) {
|
||||
switch (e.kind()) {
|
||||
case expr_kind::Lambda:
|
||||
e = binding_body(e);
|
||||
break;
|
||||
case expr_kind::Let:
|
||||
update_info_buffer(env, let_value(e), S, info_buffer);
|
||||
e = let_body(e);
|
||||
break;
|
||||
case expr_kind::App:
|
||||
if (is_cases_on_app(env, e)) {
|
||||
buffer<expr> args;
|
||||
expr const & c_fn = get_app_args(e, args);
|
||||
unsigned minors_begin; unsigned minors_end;
|
||||
std::tie(minors_begin, minors_end) = get_cases_on_minors_range(env, const_name(c_fn));
|
||||
for (unsigned i = minors_begin; i < minors_end; i++) {
|
||||
update_info_buffer(env, args[i], S, info_buffer);
|
||||
}
|
||||
} else {
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(e, args);
|
||||
if (is_constant(fn) && S.contains(const_name(fn))) {
|
||||
for (auto & entry : info_buffer) {
|
||||
if (entry.first == const_name(fn)) {
|
||||
unsigned sz = entry.second.size();
|
||||
for (unsigned i = 0; i < sz; i++) {
|
||||
if (i >= args.size() || !is_fvar(args[i])) {
|
||||
entry.second[i] = spec_arg_kind::Other;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
default:
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
environment update_spec_info(environment const & env, comp_decls const & ds) {
|
||||
name_set S;
|
||||
spec_info_buffer d_infos;
|
||||
/* Initialzie d_infos and S */
|
||||
for (comp_decl const & d : ds) {
|
||||
S.insert(d.fst());
|
||||
d_infos.push_back(pair<name, buffer<spec_arg_kind>>());
|
||||
auto & info = d_infos.back();
|
||||
info.first = d.fst();
|
||||
expr code = d.snd();
|
||||
while (is_lambda(code)) {
|
||||
if (is_inst_implicit(binding_info(code)))
|
||||
info.second.push_back(spec_arg_kind::FixedInst);
|
||||
else
|
||||
info.second.push_back(spec_arg_kind::Fixed);
|
||||
code = binding_body(code);
|
||||
}
|
||||
}
|
||||
/* Update d_infos */
|
||||
name x("_x");
|
||||
for (comp_decl const & d : ds) {
|
||||
buffer<expr> fvars;
|
||||
expr code = d.snd();
|
||||
unsigned i = 1;
|
||||
/* Create free variables for header variables. */
|
||||
while (is_lambda(code)) {
|
||||
fvars.push_back(mk_fvar(name(x, i)));
|
||||
code = binding_body(code);
|
||||
}
|
||||
code = instantiate_rev(code, fvars.size(), fvars.data());
|
||||
update_info_buffer(env, code, S, d_infos);
|
||||
}
|
||||
/* Update extension */
|
||||
environment new_env = env;
|
||||
specialize_ext ext = get_extension(env);
|
||||
names mutual_decls = map2<name>(ds, [&](comp_decl const & d) { return d.fst(); });
|
||||
for (pair<name, buffer<spec_arg_kind>> const & info : d_infos) {
|
||||
name const & n = info.first;
|
||||
spec_info si(mutual_decls, to_spec_arg_kinds(info.second));
|
||||
lean_trace(name({"compiler", "spec_info"}), tout() << n;
|
||||
for (spec_arg_kind k : info.second) {
|
||||
tout() << " " << to_str(k);
|
||||
}
|
||||
tout() << "\n";);
|
||||
new_env = module::add(new_env, std::make_shared<spec_info_modification>(n, si));
|
||||
ext.m_spec_info.insert(n, si);
|
||||
}
|
||||
return update(new_env, ext);
|
||||
}
|
||||
|
||||
class specialize_fn {
|
||||
type_checker::state m_st;
|
||||
local_ctx m_lctx;
|
||||
|
|
@ -95,13 +286,29 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
pair<environment, comp_decls> specialize(environment const & env, comp_decl const & d) {
|
||||
pair<environment, comp_decls> specialize_core(environment const & env, comp_decl const & d) {
|
||||
return specialize_fn(env)(d);
|
||||
}
|
||||
|
||||
pair<environment, comp_decls> specialize(environment env, comp_decls const & ds) {
|
||||
env = update_spec_info(env, ds);
|
||||
comp_decls r;
|
||||
for (comp_decl const & d : ds) {
|
||||
comp_decls new_ds;
|
||||
std::tie(env, new_ds) = specialize_core(env, d);
|
||||
r = append(r, new_ds);
|
||||
}
|
||||
return mk_pair(env, r);
|
||||
}
|
||||
|
||||
void initialize_specialize() {
|
||||
g_ext = new specialize_ext_reg();
|
||||
spec_info_modification::init();
|
||||
register_trace_class({"compiler", "spec_info"});
|
||||
}
|
||||
|
||||
void finalize_specialize() {
|
||||
spec_info_modification::finalize();
|
||||
delete g_ext;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ Author: Leonardo de Moura
|
|||
#include "kernel/environment.h"
|
||||
#include "library/compiler/util.h"
|
||||
namespace lean {
|
||||
pair<environment, comp_decls> specialize(environment const & env, comp_decl const & d);
|
||||
pair<environment, comp_decls> specialize(environment env, comp_decls const & ds);
|
||||
void initialize_specialize();
|
||||
void finalize_specialize();
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue