feat(library/type_context): infer

This commit is contained in:
Leonardo de Moura 2016-03-10 16:25:54 -08:00
parent a937672b0f
commit b69314a308
2 changed files with 351 additions and 6 deletions

View file

@ -38,6 +38,10 @@ struct type_context_as_extension_context : public extension_context {
}
};
/* =====================
type_context_cache
===================== */
type_context_cache::type_context_cache(environment const & env, options const & opts):
m_env(env),
m_options(opts),
@ -86,6 +90,18 @@ bool type_context_cache::should_unfold_macro(expr const &) {
return true;
}
/* =====================
type_context::tmp_locals
===================== */
type_context::tmp_locals::~tmp_locals() {
for (unsigned i = 0; i < m_locals.size(); i++)
m_ctx.pop_local();
}
/* =====================
type_context
===================== */
void type_context::init_core(transparency_mode m) {
m_used_assignment = false;
m_transparency_mode = m;
@ -165,9 +181,15 @@ optional<expr> type_context::reduce_projection(expr const & e) {
return some_expr(r);
}
bool type_context::should_unfold_macro(expr const & e) {
/* If m_transparency_mode is set to ALL, then we unfold all
macros. In this way, we make sure type inference does not fail. */
return m_transparency_mode == transparency_mode::All || m_cache->should_unfold_macro(e);
}
optional<expr> type_context::expand_macro(expr const & e) {
lean_assert(is_macro(e));
if (m_cache->should_unfold_macro(e)) {
if (should_unfold_macro(e)) {
type_context_as_extension_context ext(*this);
return macro_def(e).expand(e, ext);
} else {
@ -251,13 +273,231 @@ optional<expr> type_context::is_stuck(expr const & e) {
--------------- */
expr type_context::infer(expr const & e) {
// TODO(Leo)
return e;
flet<transparency_mode> set(m_transparency_mode, transparency_mode::All);
return infer_core(e);
}
expr type_context::infer_core(expr const & e) {
lean_assert(!is_var(e));
lean_assert(closed(e));
auto & cache = m_cache->m_infer_cache;
auto it = cache.find(e);
if (it != cache.end())
return it->second;
reset_used_assignment reset(*this);
expr r;
switch (e.kind()) {
case expr_kind::Local:
r = infer_local(e);
break;
case expr_kind::Meta:
r = infer_metavar(e);
break;
case expr_kind::Var:
lean_unreachable(); // LCOV_EXCL_LINE
case expr_kind::Sort:
r = mk_sort(mk_succ(sort_level(e)));
break;
case expr_kind::Constant:
r = infer_constant(e);
break;
case expr_kind::Macro:
r = infer_macro(e);
break;
case expr_kind::Lambda:
r = infer_lambda(e);
break;
case expr_kind::Pi:
r = infer_pi(e);
break;
case expr_kind::App:
r = infer_app(e);
break;
case expr_kind::Let:
r = infer_let(e);
break;
}
if (!m_used_assignment)
cache.insert(mk_pair(e, r));
return r;
}
expr type_context::infer_local(expr const & e) {
lean_assert(is_local(e));
if (is_local_decl_ref(e)) {
auto d = m_lctx.get_local_decl(e);
if (!d)
throw exception("infer type failed, unknown variable");
lean_assert(d);
return d->get_type();
} else {
/* Remark: depending on how we re-organize the parser, we may be able
to remove this branch. */
return mlocal_type(e);
}
}
expr type_context::infer_metavar(expr const & e) {
if (is_metavar_decl_ref(e)) {
auto d = m_mctx.get_metavar_decl(e);
if (!d)
throw exception("infer type failed, unknown metavariable");
return d->get_type();
} else if (m_tmp_mode && is_idx_metavar(e)) {
/* tmp metavariables should only occur in tmp_mode */
return mlocal_type(e);
} else {
lean_unreachable();
}
}
expr type_context::infer_constant(expr const & e) {
declaration d = env().get(const_name(e));
auto const & ps = d.get_univ_params();
auto const & ls = const_levels(e);
if (length(ps) != length(ls))
throw exception("infer type failed, incorrect number of universe levels");
return instantiate_type_univ_params(d, ls);
}
expr type_context::infer_macro(expr const & e) {
auto def = macro_def(e);
bool infer_only = true;
type_context_as_extension_context ext(*this);
return def.check_type(e, ext, infer_only).first;
}
expr type_context::infer_lambda(expr e) {
buffer<expr> es, ds;
tmp_locals ls(*this);
while (is_lambda(e)) {
es.push_back(e);
ds.push_back(binding_domain(e));
expr d = instantiate_rev(binding_domain(e), ls.size(), ls.data());
expr l = ls.push_local(binding_name(e), d, binding_info(e));
e = binding_body(e);
}
check_system("infer_type");
expr t = infer_core(instantiate_rev(e, ls.size(), ls.data()));
expr r = abstract_locals(t, ls.size(), ls.data());
unsigned i = es.size();
while (i > 0) {
--i;
r = mk_pi(binding_name(es[i]), ds[i], r, binding_info(es[i]));
}
return r;
}
optional<level> type_context::get_level_core(expr const & A) {
expr A_type = whnf(infer_core(A));
while (true) {
if (is_sort(A_type)) {
return some_level(sort_level(A_type));
} else if (is_mvar(A_type)) {
if (auto v = get_assignment(A_type)) {
A_type = *v;
} else if (!m_tmp_mode && is_metavar_decl_ref(A_type)) {
/* We should only assign A_type IF we are not in tmp mode. */
level r = m_mctx.mk_univ_metavar_decl();
assign(A_type, mk_sort(r));
return some_level(r);
} else if (m_tmp_mode && is_idx_metavar(A_type)) {
level r = mk_tmp_univ_mvar();
assign(A_type, mk_sort(r));
return some_level(r);
} else {
return none_level();
}
} else {
return none_level();
}
}
}
level type_context::get_level(expr const & A) {
if (auto r = get_level_core(A)) {
return *r;
} else {
throw exception("infer type failed, sort expected");
}
}
expr type_context::infer_pi(expr e) {
tmp_locals ls(*this);
buffer<level> us;
while (is_pi(e)) {
expr d = instantiate_rev(binding_domain(e), ls.size(), ls.data());
us.push_back(get_level(d));
expr l = ls.push_local(binding_name(e), d, binding_info(e));
e = binding_body(e);
}
e = instantiate_rev(e, ls.size(), ls.data());
level r = get_level(e);
unsigned i = ls.size();
bool imp = env().impredicative();
while (i > 0) {
--i;
r = imp ? mk_imax(us[i], r) : mk_max(us[i], r);
}
return mk_sort(r);
}
expr type_context::infer_app(expr const & e) {
check_system("infer_type");
buffer<expr> args;
expr const & f = get_app_args(e, args);
expr f_type = infer_core(f);
unsigned j = 0;
unsigned nargs = args.size();
for (unsigned i = 0; i < nargs; i++) {
if (is_pi(f_type)) {
f_type = binding_body(f_type);
} else {
f_type = whnf(instantiate_rev(f_type, i-j, args.data()+j));
if (!is_pi(f_type))
throw exception("infer type failed, Pi expected");
f_type = binding_body(f_type);
j = i;
}
}
return instantiate_rev(f_type, nargs-j, args.data()+j);
}
expr type_context::infer_let(expr e) {
/*
We may also infer the type of a let-expression by using
tmp_locals, push_let, and they closing the resulting type with
let-expressions.
It is unclear which option is the best / more efficient.
The following implementation doesn't need the extra step,
but it relies on the cache to avoid repeated work.
*/
buffer<expr> vs;
while (is_let(e)) {
expr v = instantiate_rev(let_value(e), vs.size(), vs.data());
vs.push_back(v);
e = let_body(e);
}
check_system("infer_type");
return infer_core(instantiate_rev(e, vs.size(), vs.data()));
}
expr type_context::check(expr const & e) {
// TODO(Leo)
return e;
// TODO(Leo): infer doesn't really check anything
return infer(e);
}
bool type_context::is_prop(expr const & e) {
if (env().impredicative()) {
expr t = whnf(infer(e));
return t == mk_Prop();
} else {
return false;
}
}
/* -------------------------------
@ -300,6 +540,12 @@ void type_context::assign_tmp(expr const & m, expr const & v) {
m_tmp_eassignment[to_meta_idx(m)] = v;
}
level type_context::mk_tmp_univ_mvar() {
unsigned idx = m_tmp_uassignment.size();
m_tmp_uassignment.push_back(none_level());
return mk_idx_metauniv(idx);
}
/* -----------------------------------
Uniform interface to tmp/regular metavariables
----------------------------------- */
@ -456,6 +702,54 @@ optional<declaration> type_context::is_delta(expr const & e) {
}
bool type_context::is_def_eq_core(expr const & t, expr const & s) {
check_system("is_definitionally_equal");
// TODO(Leo)
return false;
}
bool type_context::is_def_eq_binding(expr e1, expr e2) {
lean_assert(e1.kind() == e2.kind());
lean_assert(is_binding(e1));
expr_kind k = e1.kind();
tmp_locals subst(*this);
do {
optional<expr> var_e1_type;
if (binding_domain(e1) != binding_domain(e2)) {
var_e1_type = instantiate_rev(binding_domain(e1), subst.size(), subst.data());
expr var_e2_type = instantiate_rev(binding_domain(e2), subst.size(), subst.data());
if (!is_def_eq_core(var_e2_type, *var_e1_type))
return false;
}
if (!closed(binding_body(e1)) || !closed(binding_body(e2))) {
// local is used inside t or s
if (!var_e1_type)
var_e1_type = instantiate_rev(binding_domain(e1), subst.size(), subst.data());
subst.push_local(binding_name(e1), *var_e1_type);
} else {
expr const & dont_care = mk_Prop();
subst.push_local(binding_name(e1), dont_care);
}
e1 = binding_body(e1);
e2 = binding_body(e2);
} while (e1.kind() == k && e2.kind() == k);
return is_def_eq_core(instantiate_rev(e1, subst.size(), subst.data()),
instantiate_rev(e2, subst.size(), subst.data()));
}
bool type_context::is_def_eq_args(expr const & e1, expr const & e2) {
lean_assert(is_app(e1) && is_app(e2));
buffer<expr> args1, args2;
get_app_args(e1, args1);
get_app_args(e2, args2);
if (args1.size() != args2.size())
return false;
for (unsigned i = 0; i < args1.size(); i++) {
if (!is_def_eq_core(args1[i], args2[i]))
return false;
}
return true;
}
/*
struct unification_hint_fn {

View file

@ -207,6 +207,7 @@ public:
private:
void init_core(transparency_mode m);
optional<expr> reduce_projection(expr const & e);
bool should_unfold_macro(expr const & e);
optional<expr> expand_macro(expr const & e);
expr whnf_core(expr const & e);
optional<declaration> is_transparent(name const & n);
@ -217,6 +218,8 @@ private:
void assign_tmp(level const & u, level const & l);
void assign_tmp(expr const & m, expr const & v);
level mk_tmp_univ_mvar();
/* ------------
Uniform interface to tmp/regular metavariables
That is, in tmp mode they access m_tmp_eassignment and m_tmp_uassignment,
@ -232,10 +235,58 @@ public:
void assign(level const & u, level const & l);
void assign(expr const & m, expr const & v);
private:
/* ------------
Type inference
------------ */
expr infer_core(expr const & e);
expr infer_local(expr const & e);
expr infer_metavar(expr const & e);
expr infer_constant(expr const & e);
expr infer_macro(expr const & e);
expr infer_lambda(expr e);
optional<level> get_level_core(expr const & A);
level get_level(expr const & A);
expr infer_pi(expr e);
expr infer_app(expr const & e);
expr infer_let(expr e);
private:
level instantiate(level const & l);
expr instantiate(expr const & l);
bool is_def_eq(levels const & ls1, levels const & ls2);
optional<declaration> is_delta(expr const & e);
bool is_def_eq(levels const & ls1, levels const & ls2);
bool is_def_eq_core(expr const & t, expr const & s);
bool is_def_eq_binding(expr e1, expr e2);
bool is_def_eq_args(expr const & e1, expr const & e2);
public:
/* Helper class for creating pushing local declarations on m_lctx */
class tmp_locals {
type_context & m_ctx;
buffer<expr> m_locals;
public:
tmp_locals(type_context & ctx):m_ctx(ctx) {}
~tmp_locals();
expr push_local(name const & pp_name, expr const & type, binder_info const & bi = binder_info()) {
expr r = m_ctx.push_local(pp_name, type, bi);
m_locals.push_back(r);
return r;
}
expr push_let(name const & name, expr const & type, expr const & value) {
expr r = m_ctx.push_let(name, type, value);
m_locals.push_back(r);
return r;
}
unsigned size() const { return m_locals.size(); }
expr const * data() const { return m_locals.data(); }
buffer<expr> const & as_buffer() const { return m_locals; }
};
};
}