feat(library/compiler/specialize): collect dependencies

This commit is contained in:
Leonardo de Moura 2018-10-15 17:15:42 -07:00
parent 4f73cb18bb
commit 777119ceab

View file

@ -6,6 +6,7 @@ Author: Leonardo de Moura
*/
#include "runtime/flet.h"
#include "kernel/instantiate.h"
#include "kernel/for_each_fn.h"
#include "library/module.h"
#include "library/attribute_manager.h"
#include "library/compiler/util.h"
@ -43,9 +44,15 @@ static spec_arg_kinds to_spec_arg_kinds(buffer<spec_arg_kind> const & ks) {
}
return r;
}
static bool has_fixed_inst_arg(spec_arg_kinds ks) {
static void to_buffer(spec_arg_kinds const & ks, buffer<spec_arg_kind> & r) {
for (object_ref const & k : ks) {
if (to_spec_arg_kind(k) == spec_arg_kind::FixedInst)
r.push_back(to_spec_arg_kind(k));
}
}
static bool has_fixed_inst_arg(buffer<spec_arg_kind> const & ks) {
for (spec_arg_kind k : ks) {
if (k == spec_arg_kind::FixedInst)
return true;
}
return false;
@ -316,6 +323,71 @@ class specialize_fn {
return e;
}
void collect_deps(expr e, name_set & collected, buffer<expr> & new_params, buffer<expr> & let_vars) {
buffer<expr> todo;
while (true) {
for_each(e, [&](expr const & x, unsigned) {
if (!has_fvar(x)) return false;
if (is_fvar(x) && !collected.contains(fvar_name(x))) {
collected.insert(fvar_name(x));
if (optional<expr> v = m_lctx.get_local_decl(x).get_value()) {
let_vars.push_back(x);
todo.push_back(*v);
} else {
new_params.push_back(x);
}
}
return true;
});
if (todo.empty())
return;
e = todo.back();
todo.pop_back();
}
}
expr specialize(expr const & fn, buffer<expr> const & args, names const & mutual, buffer<spec_arg_kind> const & kinds, bool has_attr) {
name_set collected;
buffer<expr> new_params;
buffer<expr> let_vars;
unsigned sz = std::min(args.size(), kinds.size());
unsigned i = sz;
buffer<bool> mask;
mask.resize(args.size(), false);
bool found_inst = false;
while (i > 0) {
--i;
switch (kinds[i]) {
case spec_arg_kind::Other:
break;
case spec_arg_kind::FixedInst:
mask[i] = true;
collect_deps(args[i], collected, new_params, let_vars);
found_inst = true;
break;
case spec_arg_kind::FixedHO:
case spec_arg_kind::FixedNeutral:
case spec_arg_kind::Fixed:
if (has_attr || found_inst) {
mask[i] = true;
collect_deps(args[i], collected, new_params, let_vars);
}
break;
}
}
std::sort(new_params.begin(), new_params.end(),
[&](expr const & x, expr const & y) { return m_lctx.get_local_decl(x).get_idx() < m_lctx.get_local_decl(y).get_idx(); });
std::sort(let_vars.begin(), let_vars.end(),
[&](expr const & x, expr const & y) { return m_lctx.get_local_decl(x).get_idx() < m_lctx.get_local_decl(y).get_idx(); });
lean_trace(name({"compiler", "specialize"}),
tout() << "candidate: " << mk_app(fn, args) << "\nclosure:";
for (expr const & p : new_params) tout() << " " << p;
for (expr const & x : let_vars) tout() << " " << x;
tout() << "\n";);
// TODO(Leo):
return mk_app(fn, args);
}
expr visit_app(expr const & e) {
if (is_cases_on_app(env(), e)) {
return visit_cases_on(e);
@ -327,16 +399,16 @@ class specialize_fn {
spec_info const * info = ext.m_spec_info.find(const_name(fn));
if (!info) return e;
bool has_attr = has_specialize_attribute(env(), const_name(fn));
spec_arg_kinds kinds = info->get_arg_kinds();
buffer<spec_arg_kind> kinds;
to_buffer(info->get_arg_kinds(), kinds);
if (!has_attr && !has_fixed_inst_arg(kinds))
return e; /* Nothing to specialize */
type_checker tc(m_st, m_lctx);
bool is_candidate = false;
for (unsigned i = 0; i < args.size(); i++) {
if (empty(kinds))
if (i >= kinds.size())
break;
spec_arg_kind k = to_spec_arg_kind(head(kinds));
kinds = tail(kinds);
spec_arg_kind k = kinds[i];
expr w;
switch (k) {
case spec_arg_kind::FixedNeutral:
@ -374,9 +446,7 @@ class specialize_fn {
}
if (!is_candidate)
return e;
lean_trace(name({"compiler", "specialize"}), tout() << "candidate: " << e << "\n";);
// TODO(Leo):
return e;
return specialize(fn, args, info->get_mutual_decls(), kinds, has_attr);
}
}