diff --git a/src/library/fun_info_manager.cpp b/src/library/fun_info_manager.cpp index 94b92f0f49..d078d3aa02 100644 --- a/src/library/fun_info_manager.cpp +++ b/src/library/fun_info_manager.cpp @@ -5,6 +5,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include +#include #include "kernel/for_each_fn.h" #include "kernel/instantiate.h" #include "kernel/abstract.h" @@ -33,13 +34,14 @@ list fun_info_manager::collect_deps(expr const & type, buffer co return to_list(deps); } -fun_info fun_info_manager::get(expr const & e) { - if (auto r = m_fun_info.find(e)) - return *r; - expr type = m_ctx.relaxed_try_to_pi(m_ctx.infer(e)); - buffer info; +/* Store parameter info for fn in \c pinfos and return the dependencies of the resulting type. */ +list fun_info_manager::get_core(expr const & fn, buffer & pinfos, unsigned max_args) { + expr type = m_ctx.relaxed_try_to_pi(m_ctx.infer(fn)); buffer locals; + unsigned i = 0; while (is_pi(type)) { + if (i == max_args) + break; expr local = m_ctx.mk_tmp_local_from_binding(type); expr local_type = m_ctx.infer(local); expr new_type = m_ctx.relaxed_try_to_pi(instantiate(binding_body(type), local)); @@ -51,41 +53,152 @@ fun_info fun_info_manager::get(expr const & e) { // TODO(Leo): check if the following line is a performance bottleneck. is_sub = static_cast(m_ctx.mk_subsingleton_instance(local_type)); } - info.emplace_back(spec, - binding_info(type).is_implicit(), - binding_info(type).is_inst_implicit(), - is_prop, is_sub, is_dep, collect_deps(local_type, locals)); + pinfos.emplace_back(spec, + binding_info(type).is_implicit(), + binding_info(type).is_inst_implicit(), + is_prop, is_sub, is_dep, collect_deps(local_type, locals)); locals.push_back(local); type = new_type; + i++; } - fun_info r(info.size(), to_list(info), collect_deps(type, locals)); - m_fun_info.insert(e, r); + return collect_deps(type, locals); +} + +fun_info fun_info_manager::get(expr const & e) { + if (auto r = m_cache_get.find(e)) + return *r; + buffer pinfos; + auto result_deps = get_core(e, pinfos, std::numeric_limits::max()); + fun_info r(pinfos.size(), to_list(pinfos), result_deps); + m_cache_get.insert(e, r); return r; } fun_info fun_info_manager::get(expr const & e, unsigned nargs) { - auto r = get(e); - lean_assert(nargs <= r.get_arity()); - if (nargs == r.get_arity()) { - return r; - } else { - buffer pinfos; - to_buffer(r.get_params_info(), pinfos); - buffer rdeps; - to_buffer(r.get_result_dependencies(), rdeps); - for (unsigned i = nargs; i < pinfos.size(); i++) { - for (auto d : pinfos[i].get_dependencies()) { - if (std::find(rdeps.begin(), rdeps.end(), d) == rdeps.end()) - rdeps.push_back(d); - } + if (auto r = m_cache_get_nargs.find(mk_pair(nargs, e))) + return *r; + buffer pinfos; + auto result_deps = get_core(e, pinfos, nargs); + fun_info r(pinfos.size(), to_list(pinfos), result_deps); + m_cache_get_nargs.insert(mk_pair(nargs, e), r); + return r; +} + +/* Return true if there is j s.t. pinfos[j] is not a + proposition/subsingleton and it dependends of argument i */ +static bool has_nonprop_nonsubsingleton_fwd_dep(unsigned i, buffer const & pinfos) { + for (unsigned j = i+1; j < pinfos.size(); j++) { + param_info const & fwd_pinfo = pinfos[j]; + if (fwd_pinfo.is_prop() || fwd_pinfo.is_subsingleton()) + continue; + auto const & fwd_deps = fwd_pinfo.get_dependencies(); + if (std::find(fwd_deps.begin(), fwd_deps.end(), i) == fwd_deps.end()) { + return true; } - pinfos.shrink(nargs); - return fun_info(nargs, to_list(pinfos), to_list(rdeps)); + } + return false; +} + +fun_info fun_info_manager::get_specialization(expr const & fn, buffer const & args, buffer const & pinfos, list const & result_deps) { + buffer new_pinfos; + expr type = m_ctx.relaxed_try_to_pi(m_ctx.infer(fn)); + for (unsigned i = 0; i < args.size(); i++) { + expr new_type = m_ctx.relaxed_try_to_pi(instantiate(binding_body(type), args[i])); + expr arg_type = binding_domain(type); + param_info new_pinfo = pinfos[i]; + new_pinfo.m_specialized = true; + if (!new_pinfo.m_prop) { + new_pinfo.m_prop = m_ctx.is_prop(arg_type); + new_pinfo.m_subsingleton = new_pinfo.m_prop; + } + if (!new_pinfo.m_subsingleton) { + new_pinfo.m_subsingleton = static_cast(m_ctx.mk_subsingleton_instance(arg_type)); + } + new_pinfos.push_back(new_pinfo); + type = new_type; + } + bool spec = true; + return fun_info(new_pinfos.size(), to_list(new_pinfos), result_deps, spec); +} + +/* Copy the first prefix_sz entries from pinfos to new_pinfos and mark them as m_specialized = true */ +static void copy_prefix(unsigned prefix_sz, buffer const & pinfos, buffer & new_pinfos) { + for (unsigned i = 0; i < prefix_sz; i++) { + new_pinfos.push_back(pinfos[i].mk_specialized()); } } -fun_info fun_info_manager::get_specialization(expr const &) { - // TODO(Leo) - lean_unreachable(); +fun_info fun_info_manager::get_specialization(expr const & a) { + lean_assert(is_app(a)); + buffer args; + expr const & fn = get_app_args(a, args); + fun_info info = get(fn, args.size()); + /* + We say info is "cheap" if it is of the form: + + a) 0 or more dependent parameters p s.t. there is at least one forward dependency x : C[p] + which is not a proposition nor a subsingleton. + + b) followed by 0 or more nondependent parameter and/or a dependent parameter + s.t. all forward dependencies are propositions and subsingletons. + + We have a caching mechanism for the "cheap" case. + The cheap case cover many commonly used functions + + eq : Pi {A : Type} (x y : A), Prop + add : Pi {A : Type} [s : has_add A] (x y : A), A + inv : Pi {A : Type} [s : has_inv A] (x : A) (h : invertible x), A + + but it doesn't cover + + p : Pi {A : Type} (x : A) {B : Type} (y : B), Prop + + I don't think this is a big deal since we can write it as: + + p : Pi {A : Type} {B : Type} (x : A) (y : B), Prop + */ + buffer pinfos; + to_buffer(info.get_params_info(), pinfos); + /* Compute "prefix": 0 or more parameters s.t. + at lest one forward dependency is not a proposition or a subsingleton */ + unsigned i = 0; + for (; i < pinfos.size(); i++) { + param_info const & pinfo = pinfos[i]; + if (!pinfo.is_dep()) + break; + /* search for forward dependency that is not a proposition nor a subsingleton */ + if (!has_nonprop_nonsubsingleton_fwd_dep(i, pinfos)) + break; + } + unsigned prefix_sz = i; + /* Check if all remaining arguments are nondependent or + dependent (but all forward dependencies are propositions or subsingletons) */ + for (; i < pinfos.size(); i++) { + param_info const & pinfo = pinfos[i]; + if (!pinfo.is_dep()) + continue; /* nondependent argument */ + if (has_nonprop_nonsubsingleton_fwd_dep(i, pinfos)) + break; /* failed i-th argument has a forward dependent that is not a prop nor a subsingleton */ + } + if (i < pinfos.size()) { + /* Expensive case */ + return get_specialization(fn, args, pinfos, info.get_result_dependencies()); + } else { + if (prefix_sz == 0) + return info; + /* Get g : fn + prefix */ + unsigned num_rest_args = pinfos.size() - prefix_sz; + expr g = a; + for (unsigned i = 0; i < num_rest_args; i++) + g = app_fn(g); + if (auto r = m_cache_get_spec.find(mk_pair(num_rest_args, g))) + return *r; + buffer new_pinfos; + copy_prefix(prefix_sz, pinfos, new_pinfos); + auto result_deps = get_core(g, new_pinfos, num_rest_args); + fun_info r(new_pinfos.size(), to_list(new_pinfos), result_deps); + m_cache_get_spec.insert(mk_pair(num_rest_args, g), r); + return r; + } } } diff --git a/src/library/fun_info_manager.h b/src/library/fun_info_manager.h index 5305e9a910..6a4f3a3ab7 100644 --- a/src/library/fun_info_manager.h +++ b/src/library/fun_info_manager.h @@ -11,6 +11,7 @@ Author: Leonardo de Moura namespace lean { /** \brief Function parameter information. It is used by \c fun_info_manager. */ class param_info { + friend class fun_info_manager; /* m_specialized is true if the result of fun_info has been specifialized using this argument. For example, consider the function @@ -50,20 +51,29 @@ public: bool is_prop() const { return m_prop; } bool is_subsingleton() const { return m_subsingleton; } bool is_dep() const { return m_is_dep; } + param_info mk_specialized() const { + param_info r(*this); + r.m_specialized = true; + return r; + } }; /** \brief Function information produced by \c fun_info_manager */ class fun_info { + /* m_specialized is true if the information was produced using the function arguments, + and all m_specialized = true for all m_params_info */ unsigned m_arity; + bool m_specialized; list m_params_info; list m_deps; // resulting type dependencies public: - fun_info():m_arity(0) {} - fun_info(unsigned arity, list const & info, list const & deps): - m_arity(arity), m_params_info(info), m_deps(deps) {} + fun_info():m_arity(0), m_specialized(false) {} + fun_info(unsigned arity, list const & info, list const & deps, bool spec = false): + m_arity(arity), m_specialized(spec), m_params_info(info), m_deps(deps) {} unsigned get_arity() const { return m_arity; } list const & get_params_info() const { return m_params_info; } list const & get_result_dependencies() const { return m_deps; } + bool fully_specialized() const { return m_specialized; } }; /** \brief Helper object for retrieving a summary for the parameters @@ -72,8 +82,23 @@ public: dependencies, implicit binder info, etc. */ class fun_info_manager { type_context & m_ctx; - rb_map m_fun_info; + struct unsigned_expr_cmp { + int operator()(pair const & p1, pair const & p2) const { + if (p1.first != p2.first) + return p1.first < p2.first ? -1 : 1; + else + return expr_quick_cmp()(p1.second, p2.second); + } + }; + typedef rb_map cache; + typedef rb_map, fun_info, unsigned_expr_cmp> narg_cache; + cache m_cache_get; + narg_cache m_cache_get_nargs; + narg_cache m_cache_get_spec; list collect_deps(expr const & e, buffer const & locals); + list get_core(expr const & e, buffer & pinfos, unsigned max_args); + fun_info get_specialization(expr const & fn, buffer const & args, + buffer const & pinfos, list const & result_deps); public: fun_info_manager(type_context & ctx); type_context & ctx() { return m_ctx; }