chore(library/compiler/specialize): cleanup
Preparing to implement environment extension in Lean.
This commit is contained in:
parent
cf6f6bc96d
commit
3c0caee73b
1 changed files with 38 additions and 13 deletions
|
|
@ -153,6 +153,21 @@ struct spec_info_modification : public modification {
|
|||
}
|
||||
};
|
||||
|
||||
static environment save_specialization_info(environment const & env, name const & fn, spec_info const & si) {
|
||||
specialize_ext ext = get_extension(env);
|
||||
ext.m_spec_info.insert(fn, si);
|
||||
environment new_env = update(env, ext);
|
||||
return module::add(new_env, new spec_info_modification(fn, si));
|
||||
}
|
||||
|
||||
static optional<spec_info> get_specialization_info(environment const & env, name const & fn) {
|
||||
if (spec_info const * info = get_extension(env).m_spec_info.find(fn)) {
|
||||
return optional<spec_info>(*info);
|
||||
} else {
|
||||
return optional<spec_info>();
|
||||
}
|
||||
}
|
||||
|
||||
typedef buffer<pair<name, buffer<spec_arg_kind>>> spec_info_buffer;
|
||||
|
||||
/* We only specialize arguments that are "fixed" in mutual recursive declarations.
|
||||
|
|
@ -262,7 +277,6 @@ environment update_spec_info(environment const & env, comp_decls const & ds) {
|
|||
}
|
||||
/* Update extension */
|
||||
environment new_env = env;
|
||||
specialize_ext ext = get_extension(env);
|
||||
names mutual_decls = map2<name>(ds, [&](comp_decl const & d) { return d.fst(); });
|
||||
for (pair<name, buffer<spec_arg_kind>> const & info : d_infos) {
|
||||
name const & n = info.first;
|
||||
|
|
@ -272,10 +286,9 @@ environment update_spec_info(environment const & env, comp_decls const & ds) {
|
|||
tout() << " " << to_str(k);
|
||||
}
|
||||
tout() << "\n";);
|
||||
new_env = module::add(new_env, new spec_info_modification(n, si));
|
||||
ext.m_spec_info.insert(n, si);
|
||||
new_env = save_specialization_info(new_env, n, si);
|
||||
}
|
||||
return update(new_env, ext);
|
||||
return new_env;
|
||||
}
|
||||
|
||||
/* Support for old module manager.
|
||||
|
|
@ -305,10 +318,24 @@ struct spec_cache_modification : public modification {
|
|||
}
|
||||
};
|
||||
|
||||
static environment cache_specialization(environment const & env, expr const & k, name const & fn) {
|
||||
specialize_ext ext = get_extension(env);
|
||||
ext.m_cache.insert(k, fn);
|
||||
environment new_env = update(env, ext);
|
||||
return module::add(new_env, new spec_cache_modification(k, fn));
|
||||
}
|
||||
|
||||
static optional<name> get_cached_specialization(environment const & env, expr const & e) {
|
||||
if (name const * it = get_extension(env).m_cache.find(e)) {
|
||||
return optional<name>(*it);
|
||||
} else {
|
||||
return optional<name>();
|
||||
}
|
||||
}
|
||||
|
||||
class specialize_fn {
|
||||
type_checker::state m_st;
|
||||
csimp_cfg m_cfg;
|
||||
specialize_ext m_ext;
|
||||
local_ctx m_lctx;
|
||||
buffer<comp_decl> m_new_decls;
|
||||
name m_base_name;
|
||||
|
|
@ -404,7 +431,7 @@ class specialize_fn {
|
|||
};
|
||||
|
||||
void get_arg_kinds(name const & fn, buffer<spec_arg_kind> & kinds) {
|
||||
spec_info const * info = m_ext.m_spec_info.find(fn);
|
||||
optional<spec_info> info = get_specialization_info(env(), fn);
|
||||
lean_assert(info);
|
||||
to_buffer(info->get_arg_kinds(), kinds);
|
||||
}
|
||||
|
|
@ -1032,7 +1059,7 @@ class specialize_fn {
|
|||
expr key;
|
||||
if (gcache_enabled) {
|
||||
key = mk_app(fn, gcache_key_args);
|
||||
if (name const * it = m_ext.m_cache.find(key))
|
||||
if (optional<name> it = get_cached_specialization(env(), key))
|
||||
new_fn_name = *it;
|
||||
}
|
||||
if (!new_fn_name) {
|
||||
|
|
@ -1044,8 +1071,7 @@ class specialize_fn {
|
|||
mk_new_decl(pre_decl, fvars, fvar_vals, ctx);
|
||||
}
|
||||
if (gcache_enabled) {
|
||||
m_ext.m_cache.insert(key, *new_fn_name);
|
||||
m_st.env() = module::add(env(), new spec_cache_modification(key, *new_fn_name));
|
||||
m_st.env() = cache_specialization(env(), key, *new_fn_name);
|
||||
}
|
||||
}
|
||||
expr r = mk_constant(*new_fn_name);
|
||||
|
|
@ -1071,7 +1097,7 @@ class specialize_fn {
|
|||
|| (is_instance(env(), const_name(fn)) && !has_specialize_attribute(env(), const_name(fn)))) {
|
||||
return e;
|
||||
}
|
||||
spec_info const * info = m_ext.m_spec_info.find(const_name(fn));
|
||||
optional<spec_info> info = get_specialization_info(env(), const_name(fn));
|
||||
if (!info) return e;
|
||||
spec_ctx ctx;
|
||||
ctx.m_mutual = info->get_mutual_decls();
|
||||
|
|
@ -1093,15 +1119,14 @@ class specialize_fn {
|
|||
|
||||
public:
|
||||
specialize_fn(environment const & env, csimp_cfg const & cfg):
|
||||
m_st(env), m_cfg(cfg), m_ext(get_extension(env)), m_at("_at"), m_spec("_spec") {}
|
||||
m_st(env), m_cfg(cfg), m_at("_at"), m_spec("_spec") {}
|
||||
|
||||
pair<environment, comp_decls> operator()(comp_decl const & d) {
|
||||
m_base_name = d.fst();
|
||||
lean_trace(name({"compiler", "specialize"}), tout() << "INPUT: " << d.fst() << "\n" << d.snd() << "\n";);
|
||||
expr new_v = visit(d.snd());
|
||||
comp_decl new_d(d.fst(), new_v);
|
||||
environment new_env = update(env(), m_ext);
|
||||
return mk_pair(new_env, append(comp_decls(m_new_decls), comp_decls(new_d)));
|
||||
return mk_pair(env(), append(comp_decls(m_new_decls), comp_decls(new_d)));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue