From 3c0caee73b168306d5230bfe717748ad2eccc17f Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 27 Jun 2019 09:32:25 -0700 Subject: [PATCH] chore(library/compiler/specialize): cleanup Preparing to implement environment extension in Lean. --- src/library/compiler/specialize.cpp | 51 +++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/src/library/compiler/specialize.cpp b/src/library/compiler/specialize.cpp index b49c208d2d..65559b281e 100644 --- a/src/library/compiler/specialize.cpp +++ b/src/library/compiler/specialize.cpp @@ -153,6 +153,21 @@ struct spec_info_modification : public modification { } }; +static environment save_specialization_info(environment const & env, name const & fn, spec_info const & si) { + specialize_ext ext = get_extension(env); + ext.m_spec_info.insert(fn, si); + environment new_env = update(env, ext); + return module::add(new_env, new spec_info_modification(fn, si)); +} + +static optional get_specialization_info(environment const & env, name const & fn) { + if (spec_info const * info = get_extension(env).m_spec_info.find(fn)) { + return optional(*info); + } else { + return optional(); + } +} + typedef buffer>> spec_info_buffer; /* We only specialize arguments that are "fixed" in mutual recursive declarations. @@ -262,7 +277,6 @@ environment update_spec_info(environment const & env, comp_decls const & ds) { } /* Update extension */ environment new_env = env; - specialize_ext ext = get_extension(env); names mutual_decls = map2(ds, [&](comp_decl const & d) { return d.fst(); }); for (pair> const & info : d_infos) { name const & n = info.first; @@ -272,10 +286,9 @@ environment update_spec_info(environment const & env, comp_decls const & ds) { tout() << " " << to_str(k); } tout() << "\n";); - new_env = module::add(new_env, new spec_info_modification(n, si)); - ext.m_spec_info.insert(n, si); + new_env = save_specialization_info(new_env, n, si); } - return update(new_env, ext); + return new_env; } /* Support for old module manager. @@ -305,10 +318,24 @@ struct spec_cache_modification : public modification { } }; +static environment cache_specialization(environment const & env, expr const & k, name const & fn) { + specialize_ext ext = get_extension(env); + ext.m_cache.insert(k, fn); + environment new_env = update(env, ext); + return module::add(new_env, new spec_cache_modification(k, fn)); +} + +static optional get_cached_specialization(environment const & env, expr const & e) { + if (name const * it = get_extension(env).m_cache.find(e)) { + return optional(*it); + } else { + return optional(); + } +} + class specialize_fn { type_checker::state m_st; csimp_cfg m_cfg; - specialize_ext m_ext; local_ctx m_lctx; buffer m_new_decls; name m_base_name; @@ -404,7 +431,7 @@ class specialize_fn { }; void get_arg_kinds(name const & fn, buffer & kinds) { - spec_info const * info = m_ext.m_spec_info.find(fn); + optional info = get_specialization_info(env(), fn); lean_assert(info); to_buffer(info->get_arg_kinds(), kinds); } @@ -1032,7 +1059,7 @@ class specialize_fn { expr key; if (gcache_enabled) { key = mk_app(fn, gcache_key_args); - if (name const * it = m_ext.m_cache.find(key)) + if (optional it = get_cached_specialization(env(), key)) new_fn_name = *it; } if (!new_fn_name) { @@ -1044,8 +1071,7 @@ class specialize_fn { mk_new_decl(pre_decl, fvars, fvar_vals, ctx); } if (gcache_enabled) { - m_ext.m_cache.insert(key, *new_fn_name); - m_st.env() = module::add(env(), new spec_cache_modification(key, *new_fn_name)); + m_st.env() = cache_specialization(env(), key, *new_fn_name); } } expr r = mk_constant(*new_fn_name); @@ -1071,7 +1097,7 @@ class specialize_fn { || (is_instance(env(), const_name(fn)) && !has_specialize_attribute(env(), const_name(fn)))) { return e; } - spec_info const * info = m_ext.m_spec_info.find(const_name(fn)); + optional info = get_specialization_info(env(), const_name(fn)); if (!info) return e; spec_ctx ctx; ctx.m_mutual = info->get_mutual_decls(); @@ -1093,15 +1119,14 @@ class specialize_fn { public: specialize_fn(environment const & env, csimp_cfg const & cfg): - m_st(env), m_cfg(cfg), m_ext(get_extension(env)), m_at("_at"), m_spec("_spec") {} + m_st(env), m_cfg(cfg), m_at("_at"), m_spec("_spec") {} pair operator()(comp_decl const & d) { m_base_name = d.fst(); lean_trace(name({"compiler", "specialize"}), tout() << "INPUT: " << d.fst() << "\n" << d.snd() << "\n";); expr new_v = visit(d.snd()); comp_decl new_d(d.fst(), new_v); - environment new_env = update(env(), m_ext); - return mk_pair(new_env, append(comp_decls(m_new_decls), comp_decls(new_d))); + return mk_pair(env(), append(comp_decls(m_new_decls), comp_decls(new_d))); } };