chore(library/compiler/specialize): cleanup

Preparing to implement environment extension in Lean.
This commit is contained in:
Leonardo de Moura 2019-06-27 09:32:25 -07:00
parent cf6f6bc96d
commit 3c0caee73b

View file

@ -153,6 +153,21 @@ struct spec_info_modification : public modification {
}
};
static environment save_specialization_info(environment const & env, name const & fn, spec_info const & si) {
specialize_ext ext = get_extension(env);
ext.m_spec_info.insert(fn, si);
environment new_env = update(env, ext);
return module::add(new_env, new spec_info_modification(fn, si));
}
static optional<spec_info> get_specialization_info(environment const & env, name const & fn) {
if (spec_info const * info = get_extension(env).m_spec_info.find(fn)) {
return optional<spec_info>(*info);
} else {
return optional<spec_info>();
}
}
typedef buffer<pair<name, buffer<spec_arg_kind>>> spec_info_buffer;
/* We only specialize arguments that are "fixed" in mutual recursive declarations.
@ -262,7 +277,6 @@ environment update_spec_info(environment const & env, comp_decls const & ds) {
}
/* Update extension */
environment new_env = env;
specialize_ext ext = get_extension(env);
names mutual_decls = map2<name>(ds, [&](comp_decl const & d) { return d.fst(); });
for (pair<name, buffer<spec_arg_kind>> const & info : d_infos) {
name const & n = info.first;
@ -272,10 +286,9 @@ environment update_spec_info(environment const & env, comp_decls const & ds) {
tout() << " " << to_str(k);
}
tout() << "\n";);
new_env = module::add(new_env, new spec_info_modification(n, si));
ext.m_spec_info.insert(n, si);
new_env = save_specialization_info(new_env, n, si);
}
return update(new_env, ext);
return new_env;
}
/* Support for old module manager.
@ -305,10 +318,24 @@ struct spec_cache_modification : public modification {
}
};
static environment cache_specialization(environment const & env, expr const & k, name const & fn) {
specialize_ext ext = get_extension(env);
ext.m_cache.insert(k, fn);
environment new_env = update(env, ext);
return module::add(new_env, new spec_cache_modification(k, fn));
}
static optional<name> get_cached_specialization(environment const & env, expr const & e) {
if (name const * it = get_extension(env).m_cache.find(e)) {
return optional<name>(*it);
} else {
return optional<name>();
}
}
class specialize_fn {
type_checker::state m_st;
csimp_cfg m_cfg;
specialize_ext m_ext;
local_ctx m_lctx;
buffer<comp_decl> m_new_decls;
name m_base_name;
@ -404,7 +431,7 @@ class specialize_fn {
};
void get_arg_kinds(name const & fn, buffer<spec_arg_kind> & kinds) {
spec_info const * info = m_ext.m_spec_info.find(fn);
optional<spec_info> info = get_specialization_info(env(), fn);
lean_assert(info);
to_buffer(info->get_arg_kinds(), kinds);
}
@ -1032,7 +1059,7 @@ class specialize_fn {
expr key;
if (gcache_enabled) {
key = mk_app(fn, gcache_key_args);
if (name const * it = m_ext.m_cache.find(key))
if (optional<name> it = get_cached_specialization(env(), key))
new_fn_name = *it;
}
if (!new_fn_name) {
@ -1044,8 +1071,7 @@ class specialize_fn {
mk_new_decl(pre_decl, fvars, fvar_vals, ctx);
}
if (gcache_enabled) {
m_ext.m_cache.insert(key, *new_fn_name);
m_st.env() = module::add(env(), new spec_cache_modification(key, *new_fn_name));
m_st.env() = cache_specialization(env(), key, *new_fn_name);
}
}
expr r = mk_constant(*new_fn_name);
@ -1071,7 +1097,7 @@ class specialize_fn {
|| (is_instance(env(), const_name(fn)) && !has_specialize_attribute(env(), const_name(fn)))) {
return e;
}
spec_info const * info = m_ext.m_spec_info.find(const_name(fn));
optional<spec_info> info = get_specialization_info(env(), const_name(fn));
if (!info) return e;
spec_ctx ctx;
ctx.m_mutual = info->get_mutual_decls();
@ -1093,15 +1119,14 @@ class specialize_fn {
public:
specialize_fn(environment const & env, csimp_cfg const & cfg):
m_st(env), m_cfg(cfg), m_ext(get_extension(env)), m_at("_at"), m_spec("_spec") {}
m_st(env), m_cfg(cfg), m_at("_at"), m_spec("_spec") {}
pair<environment, comp_decls> operator()(comp_decl const & d) {
m_base_name = d.fst();
lean_trace(name({"compiler", "specialize"}), tout() << "INPUT: " << d.fst() << "\n" << d.snd() << "\n";);
expr new_v = visit(d.snd());
comp_decl new_d(d.fst(), new_v);
environment new_env = update(env(), m_ext);
return mk_pair(new_env, append(comp_decls(m_new_decls), comp_decls(new_d)));
return mk_pair(env(), append(comp_decls(m_new_decls), comp_decls(new_d)));
}
};