feat(library/compiler/specialize): collect dependencies
This commit is contained in:
parent
4f73cb18bb
commit
777119ceab
1 changed files with 79 additions and 9 deletions
|
|
@ -6,6 +6,7 @@ Author: Leonardo de Moura
|
|||
*/
|
||||
#include "runtime/flet.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "kernel/for_each_fn.h"
|
||||
#include "library/module.h"
|
||||
#include "library/attribute_manager.h"
|
||||
#include "library/compiler/util.h"
|
||||
|
|
@ -43,9 +44,15 @@ static spec_arg_kinds to_spec_arg_kinds(buffer<spec_arg_kind> const & ks) {
|
|||
}
|
||||
return r;
|
||||
}
|
||||
static bool has_fixed_inst_arg(spec_arg_kinds ks) {
|
||||
static void to_buffer(spec_arg_kinds const & ks, buffer<spec_arg_kind> & r) {
|
||||
for (object_ref const & k : ks) {
|
||||
if (to_spec_arg_kind(k) == spec_arg_kind::FixedInst)
|
||||
r.push_back(to_spec_arg_kind(k));
|
||||
}
|
||||
}
|
||||
|
||||
static bool has_fixed_inst_arg(buffer<spec_arg_kind> const & ks) {
|
||||
for (spec_arg_kind k : ks) {
|
||||
if (k == spec_arg_kind::FixedInst)
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
|
@ -316,6 +323,71 @@ class specialize_fn {
|
|||
return e;
|
||||
}
|
||||
|
||||
void collect_deps(expr e, name_set & collected, buffer<expr> & new_params, buffer<expr> & let_vars) {
|
||||
buffer<expr> todo;
|
||||
while (true) {
|
||||
for_each(e, [&](expr const & x, unsigned) {
|
||||
if (!has_fvar(x)) return false;
|
||||
if (is_fvar(x) && !collected.contains(fvar_name(x))) {
|
||||
collected.insert(fvar_name(x));
|
||||
if (optional<expr> v = m_lctx.get_local_decl(x).get_value()) {
|
||||
let_vars.push_back(x);
|
||||
todo.push_back(*v);
|
||||
} else {
|
||||
new_params.push_back(x);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
if (todo.empty())
|
||||
return;
|
||||
e = todo.back();
|
||||
todo.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
expr specialize(expr const & fn, buffer<expr> const & args, names const & mutual, buffer<spec_arg_kind> const & kinds, bool has_attr) {
|
||||
name_set collected;
|
||||
buffer<expr> new_params;
|
||||
buffer<expr> let_vars;
|
||||
unsigned sz = std::min(args.size(), kinds.size());
|
||||
unsigned i = sz;
|
||||
buffer<bool> mask;
|
||||
mask.resize(args.size(), false);
|
||||
bool found_inst = false;
|
||||
while (i > 0) {
|
||||
--i;
|
||||
switch (kinds[i]) {
|
||||
case spec_arg_kind::Other:
|
||||
break;
|
||||
case spec_arg_kind::FixedInst:
|
||||
mask[i] = true;
|
||||
collect_deps(args[i], collected, new_params, let_vars);
|
||||
found_inst = true;
|
||||
break;
|
||||
case spec_arg_kind::FixedHO:
|
||||
case spec_arg_kind::FixedNeutral:
|
||||
case spec_arg_kind::Fixed:
|
||||
if (has_attr || found_inst) {
|
||||
mask[i] = true;
|
||||
collect_deps(args[i], collected, new_params, let_vars);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
std::sort(new_params.begin(), new_params.end(),
|
||||
[&](expr const & x, expr const & y) { return m_lctx.get_local_decl(x).get_idx() < m_lctx.get_local_decl(y).get_idx(); });
|
||||
std::sort(let_vars.begin(), let_vars.end(),
|
||||
[&](expr const & x, expr const & y) { return m_lctx.get_local_decl(x).get_idx() < m_lctx.get_local_decl(y).get_idx(); });
|
||||
lean_trace(name({"compiler", "specialize"}),
|
||||
tout() << "candidate: " << mk_app(fn, args) << "\nclosure:";
|
||||
for (expr const & p : new_params) tout() << " " << p;
|
||||
for (expr const & x : let_vars) tout() << " " << x;
|
||||
tout() << "\n";);
|
||||
// TODO(Leo):
|
||||
return mk_app(fn, args);
|
||||
}
|
||||
|
||||
expr visit_app(expr const & e) {
|
||||
if (is_cases_on_app(env(), e)) {
|
||||
return visit_cases_on(e);
|
||||
|
|
@ -327,16 +399,16 @@ class specialize_fn {
|
|||
spec_info const * info = ext.m_spec_info.find(const_name(fn));
|
||||
if (!info) return e;
|
||||
bool has_attr = has_specialize_attribute(env(), const_name(fn));
|
||||
spec_arg_kinds kinds = info->get_arg_kinds();
|
||||
buffer<spec_arg_kind> kinds;
|
||||
to_buffer(info->get_arg_kinds(), kinds);
|
||||
if (!has_attr && !has_fixed_inst_arg(kinds))
|
||||
return e; /* Nothing to specialize */
|
||||
type_checker tc(m_st, m_lctx);
|
||||
bool is_candidate = false;
|
||||
for (unsigned i = 0; i < args.size(); i++) {
|
||||
if (empty(kinds))
|
||||
if (i >= kinds.size())
|
||||
break;
|
||||
spec_arg_kind k = to_spec_arg_kind(head(kinds));
|
||||
kinds = tail(kinds);
|
||||
spec_arg_kind k = kinds[i];
|
||||
expr w;
|
||||
switch (k) {
|
||||
case spec_arg_kind::FixedNeutral:
|
||||
|
|
@ -374,9 +446,7 @@ class specialize_fn {
|
|||
}
|
||||
if (!is_candidate)
|
||||
return e;
|
||||
lean_trace(name({"compiler", "specialize"}), tout() << "candidate: " << e << "\n";);
|
||||
// TODO(Leo):
|
||||
return e;
|
||||
return specialize(fn, args, info->get_mutual_decls(), kinds, has_attr);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue