/* Copyright (c) 2016 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include #include #include "kernel/for_each_fn.h" #include "kernel/instantiate.h" #include "kernel/abstract.h" #include "kernel/expr_maps.h" #include "library/trace.h" #include "library/util.h" #include "library/constants.h" #include "library/expr_unsigned_map.h" #include "library/fun_info.h" #include "library/cache_helper.h" namespace lean { static name * g_fun_info = nullptr; #define lean_trace_fun_info(Code) lean_trace(*g_fun_info, Code) static bool is_fun_info_trace_enabled() { return is_trace_class_enabled(*g_fun_info); } void initialize_fun_info() { g_fun_info = new name("fun_info"); register_trace_class(*g_fun_info); } void finalize_fun_info() { delete g_fun_info; } /* Return the list of dependencies for the given type. If it depends on local i, then it also sets pinfos[i].set_has_fwd_deps(); */ static list collect_deps(expr const & type, buffer const & locals, buffer & pinfos) { buffer deps; for_each(type, [&](expr const & e, unsigned) { if (is_local(e)) { unsigned idx; for (idx = 0; idx < locals.size(); idx++) if (locals[idx] == e) break; if (idx < locals.size() && std::find(deps.begin(), deps.end(), idx) == deps.end()) { deps.push_back(idx); pinfos[idx].set_has_fwd_deps(); } } return has_local(e); // continue the search only if e has locals }); std::sort(deps.begin(), deps.end()); return to_list(deps); } /* Store parameter info for fn in \c pinfos and return the dependencies of the resulting type. */ static list get_core(type_context_old & ctx, expr const & fn, buffer & pinfos, unsigned max_args) { expr type = ctx.relaxed_try_to_pi(ctx.infer(fn)); type_context_old::tmp_locals locals(ctx); unsigned i = 0; while (is_pi(type)) { if (i == max_args) break; expr local_type = consume_auto_opt_param(binding_domain(type)); expr local = locals.push_local(binding_name(type), local_type, binding_info(type)); expr new_type = ctx.relaxed_try_to_pi(instantiate(binding_body(type), local)); bool is_prop = ctx.is_prop(local_type); bool is_dep = false; /* it is set by collect_deps */ binder_info bi = binding_info(type); pinfos.emplace_back(is_implicit(bi), is_inst_implicit(bi), is_prop, is_dep, collect_deps(local_type, locals.as_buffer(), pinfos)); type = new_type; i++; } return collect_deps(type, locals.as_buffer(), pinfos); } fun_info get_fun_info(type_context_old & ctx, expr const & e) { abstract_context_cache & cache = ctx.get_cache(); if (auto r = cache.get_fun_info(ctx.mode(), e)) return *r; buffer pinfos; auto result_deps = get_core(ctx, e, pinfos, std::numeric_limits::max()); fun_info r(pinfos.size(), to_list(pinfos), result_deps); cache.set_fun_info(ctx.mode(), e, r); return r; } fun_info get_fun_info(type_context_old & ctx, expr const & e, unsigned nargs) { abstract_context_cache & cache = ctx.get_cache(); if (auto r = cache.get_fun_info_nargs(ctx.mode(), e, nargs)) return *r; buffer pinfos; auto result_deps = get_core(ctx, e, pinfos, nargs); fun_info r(pinfos.size(), to_list(pinfos), result_deps); cache.set_fun_info_nargs(ctx.mode(), e, nargs, r); return r; } /* Store subsingleton parameter info for fn in \c ssinfos */ static void get_ss_core(type_context_old & ctx, expr const & fn, buffer & ssinfos, unsigned max_args) { expr type = ctx.relaxed_try_to_pi(ctx.infer(fn)); type_context_old::tmp_locals locals(ctx); unsigned i = 0; while (is_pi(type)) { if (i == max_args) break; expr local = locals.push_local_from_binding(type); expr local_type = ctx.infer(local); expr new_type = ctx.relaxed_try_to_pi(instantiate(binding_body(type), local)); bool spec = false; bool is_prop = ctx.is_prop(local_type); bool is_sub = is_prop; if (!is_sub) { // TODO(Leo): check if the following line is a performance bottleneck. is_sub = static_cast(ctx.mk_subsingleton_instance(local_type)); } ssinfos.emplace_back(spec, is_sub); type = new_type; i++; } } ss_param_infos get_subsingleton_info(type_context_old & ctx, expr const & e) { abstract_context_cache & cache = ctx.get_cache(); if (auto r = cache.get_subsingleton_info(ctx.mode(), e)) return *r; buffer ssinfos; get_ss_core(ctx, e, ssinfos, std::numeric_limits::max()); ss_param_infos r = to_list(ssinfos); cache.set_subsingleton_info(ctx.mode(), e, r); return r; } ss_param_infos get_subsingleton_info(type_context_old & ctx, expr const & e, unsigned nargs) { abstract_context_cache & cache = ctx.get_cache(); if (auto r = cache.get_subsingleton_info_nargs(ctx.mode(), e, nargs)) return *r; buffer ssinfos; get_ss_core(ctx, e, ssinfos, nargs); ss_param_infos r = to_list(ssinfos); cache.set_subsingleton_info_nargs(ctx.mode(), e, nargs, r); return r; } /* Return true if there is j s.t. ssinfos[j] is marked as subsingleton, and it dependends of argument i */ static bool has_nonsubsingleton_fwd_dep(unsigned i, buffer const & pinfos, buffer const & ssinfos) { lean_assert(pinfos.size() == ssinfos.size()); for (unsigned j = i+1; j < pinfos.size(); j++) { if (ssinfos[j].is_subsingleton()) continue; auto const & back_deps = pinfos[j].get_back_deps(); if (std::find(back_deps.begin(), back_deps.end(), i) != back_deps.end()) { return true; } } return false; } static void trace_if_unsupported(type_context_old & ctx, expr const & fn, buffer const & args, unsigned prefix_sz, ss_param_infos const & result) { lean_assert(args.size() >= length(result)); if (!is_fun_info_trace_enabled()) return; fun_info info = get_fun_info(ctx, fn, args.size()); buffer pinfos; to_buffer(info.get_params_info(), pinfos); buffer ssinfos; to_buffer(get_subsingleton_info(ctx, fn, args.size()), ssinfos); lean_assert(pinfos.size() == ssinfos.size()); /* Check if all remaining arguments are nondependent or dependent (but all forward dependencies are subsingletons) */ unsigned i = prefix_sz; for (; i < pinfos.size(); i++) { param_info const & pinfo = pinfos[i]; if (!pinfo.has_fwd_deps()) continue; /* nondependent argument */ if (has_nonsubsingleton_fwd_dep(i, pinfos, ssinfos)) break; /* failed i-th argument has a forward dependent that is not a prop nor a subsingleton */ } if (i == pinfos.size()) return; // It is *cheap* case /* Expensive case */ /* We generate a trace message IF it would be possible to compute more precise information. That is, there is an argument that is a proposition and/or subsingleton, but the corresponding pinfo is not a marked a prop/subsingleton. */ i = 0; for (ss_param_info const & ssinfo : result) { if (ssinfo.is_subsingleton()) continue; expr arg_type = ctx.infer(args[i]); if (ctx.mk_subsingleton_instance(arg_type)) { lean_trace_fun_info( tout() << "approximating function information for '" << fn << "', this may affect the effectiveness of the simplifier and congruence closure modules, " << "more precise information can be efficiently computed if all parameters are moved to the " << "beginning of the function\n";); return; } i++; } } unsigned get_specialization_prefix_size(type_context_old & ctx, expr const & fn, unsigned nargs) { /* We say a function is "cheap" if it is of the form: a) 0 or more dependent parameters p s.t. there is at least one forward dependency x : C[p] which is not a proposition nor a subsingleton. b) followed by 0 or more nondependent parameter and/or a dependent parameter s.t. all forward dependencies are propositions and subsingletons. We have a caching mechanism for the "cheap" case. The cheap case cover many commonly used functions eq : Pi {A : Type} (x y : A), Prop add : Pi {A : Type} [s : has_add A] (x y : A), A inv : Pi {A : Type} [s : has_inv A] (x : A) (h : invertible x), A but it doesn't cover p : Pi {A : Type} (x : A) {B : Type} (y : B), Prop I don't think this is a big deal since we can write it as: p : Pi {A : Type} {B : Type} (x : A) (y : B), Prop Therefore, we ignore the non-cheap cases, and pretend they are "cheap". If tracing is enabled, we produce a tracing message whenever we find a non-cheap case. This procecure returns the size of group a) */ abstract_context_cache & cache = ctx.get_cache(); if (auto r = cache.get_specialization_prefix_size(ctx.mode(), fn, nargs)) return *r; fun_info info = get_fun_info(ctx, fn, nargs); buffer pinfos; to_buffer(info.get_params_info(), pinfos); buffer ssinfos; to_buffer(get_subsingleton_info(ctx, fn, nargs), ssinfos); lean_assert(pinfos.size() == ssinfos.size()); /* Compute "prefix": 0 or more parameters s.t. at lest one forward dependency is not a proposition or a subsingleton */ unsigned i = 0; for (; i < pinfos.size(); i++) { param_info const & pinfo = pinfos[i]; if (!pinfo.has_fwd_deps()) break; /* search for forward dependency that is not a proposition nor a subsingleton */ if (!has_nonsubsingleton_fwd_dep(i, pinfos, ssinfos)) break; } unsigned prefix_sz = i; cache.set_specialization_prefix_size(ctx.mode(), fn, nargs, prefix_sz); return prefix_sz; } ss_param_infos get_specialized_subsingleton_info(type_context_old & ctx, expr const & a) { lean_assert(is_app(a)); buffer args; expr const & fn = get_app_args(a, args); unsigned prefix_sz = get_specialization_prefix_size(ctx, fn, args.size()); unsigned num_rest_args = args.size() - prefix_sz; expr g = a; for (unsigned i = 0; i < num_rest_args; i++) g = app_fn(g); abstract_context_cache & cache = ctx.get_cache(); if (auto r = cache.get_specialized_subsingleton_info_nargs(ctx.mode(), g, num_rest_args)) return *r; buffer ssinfos; get_ss_core(ctx, fn, ssinfos, prefix_sz); for (unsigned i = 0; i < prefix_sz; i++) { ssinfos[i].set_specialized(); } get_ss_core(ctx, g, ssinfos, num_rest_args); ss_param_infos r = to_list(ssinfos); cache.set_specialization_subsingleton_info_nargs(ctx.mode(), g, num_rest_args, r); trace_if_unsupported(ctx, fn, args, prefix_sz, r); return r; } }