From 777119ceabe33ac06f731a382a0fdee9d2ca3969 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 15 Oct 2018 17:15:42 -0700 Subject: [PATCH] feat(library/compiler/specialize): collect dependencies --- src/library/compiler/specialize.cpp | 88 ++++++++++++++++++++++++++--- 1 file changed, 79 insertions(+), 9 deletions(-) diff --git a/src/library/compiler/specialize.cpp b/src/library/compiler/specialize.cpp index ec0966c1e0..716bbb710a 100644 --- a/src/library/compiler/specialize.cpp +++ b/src/library/compiler/specialize.cpp @@ -6,6 +6,7 @@ Author: Leonardo de Moura */ #include "runtime/flet.h" #include "kernel/instantiate.h" +#include "kernel/for_each_fn.h" #include "library/module.h" #include "library/attribute_manager.h" #include "library/compiler/util.h" @@ -43,9 +44,15 @@ static spec_arg_kinds to_spec_arg_kinds(buffer const & ks) { } return r; } -static bool has_fixed_inst_arg(spec_arg_kinds ks) { +static void to_buffer(spec_arg_kinds const & ks, buffer & r) { for (object_ref const & k : ks) { - if (to_spec_arg_kind(k) == spec_arg_kind::FixedInst) + r.push_back(to_spec_arg_kind(k)); + } +} + +static bool has_fixed_inst_arg(buffer const & ks) { + for (spec_arg_kind k : ks) { + if (k == spec_arg_kind::FixedInst) return true; } return false; @@ -316,6 +323,71 @@ class specialize_fn { return e; } + void collect_deps(expr e, name_set & collected, buffer & new_params, buffer & let_vars) { + buffer todo; + while (true) { + for_each(e, [&](expr const & x, unsigned) { + if (!has_fvar(x)) return false; + if (is_fvar(x) && !collected.contains(fvar_name(x))) { + collected.insert(fvar_name(x)); + if (optional v = m_lctx.get_local_decl(x).get_value()) { + let_vars.push_back(x); + todo.push_back(*v); + } else { + new_params.push_back(x); + } + } + return true; + }); + if (todo.empty()) + return; + e = todo.back(); + todo.pop_back(); + } + } + + expr specialize(expr const & fn, buffer const & args, names const & mutual, buffer const & kinds, bool has_attr) { + name_set collected; + buffer new_params; + buffer let_vars; + unsigned sz = std::min(args.size(), kinds.size()); + unsigned i = sz; + buffer mask; + mask.resize(args.size(), false); + bool found_inst = false; + while (i > 0) { + --i; + switch (kinds[i]) { + case spec_arg_kind::Other: + break; + case spec_arg_kind::FixedInst: + mask[i] = true; + collect_deps(args[i], collected, new_params, let_vars); + found_inst = true; + break; + case spec_arg_kind::FixedHO: + case spec_arg_kind::FixedNeutral: + case spec_arg_kind::Fixed: + if (has_attr || found_inst) { + mask[i] = true; + collect_deps(args[i], collected, new_params, let_vars); + } + break; + } + } + std::sort(new_params.begin(), new_params.end(), + [&](expr const & x, expr const & y) { return m_lctx.get_local_decl(x).get_idx() < m_lctx.get_local_decl(y).get_idx(); }); + std::sort(let_vars.begin(), let_vars.end(), + [&](expr const & x, expr const & y) { return m_lctx.get_local_decl(x).get_idx() < m_lctx.get_local_decl(y).get_idx(); }); + lean_trace(name({"compiler", "specialize"}), + tout() << "candidate: " << mk_app(fn, args) << "\nclosure:"; + for (expr const & p : new_params) tout() << " " << p; + for (expr const & x : let_vars) tout() << " " << x; + tout() << "\n";); + // TODO(Leo): + return mk_app(fn, args); + } + expr visit_app(expr const & e) { if (is_cases_on_app(env(), e)) { return visit_cases_on(e); @@ -327,16 +399,16 @@ class specialize_fn { spec_info const * info = ext.m_spec_info.find(const_name(fn)); if (!info) return e; bool has_attr = has_specialize_attribute(env(), const_name(fn)); - spec_arg_kinds kinds = info->get_arg_kinds(); + buffer kinds; + to_buffer(info->get_arg_kinds(), kinds); if (!has_attr && !has_fixed_inst_arg(kinds)) return e; /* Nothing to specialize */ type_checker tc(m_st, m_lctx); bool is_candidate = false; for (unsigned i = 0; i < args.size(); i++) { - if (empty(kinds)) + if (i >= kinds.size()) break; - spec_arg_kind k = to_spec_arg_kind(head(kinds)); - kinds = tail(kinds); + spec_arg_kind k = kinds[i]; expr w; switch (k) { case spec_arg_kind::FixedNeutral: @@ -374,9 +446,7 @@ class specialize_fn { } if (!is_candidate) return e; - lean_trace(name({"compiler", "specialize"}), tout() << "candidate: " << e << "\n";); - // TODO(Leo): - return e; + return specialize(fn, args, info->get_mutual_decls(), kinds, has_attr); } }