feat(library/type_context): infer
This commit is contained in:
parent
a937672b0f
commit
b69314a308
2 changed files with 351 additions and 6 deletions
|
|
@ -38,6 +38,10 @@ struct type_context_as_extension_context : public extension_context {
|
|||
}
|
||||
};
|
||||
|
||||
/* =====================
|
||||
type_context_cache
|
||||
===================== */
|
||||
|
||||
type_context_cache::type_context_cache(environment const & env, options const & opts):
|
||||
m_env(env),
|
||||
m_options(opts),
|
||||
|
|
@ -86,6 +90,18 @@ bool type_context_cache::should_unfold_macro(expr const &) {
|
|||
return true;
|
||||
}
|
||||
|
||||
/* =====================
|
||||
type_context::tmp_locals
|
||||
===================== */
|
||||
type_context::tmp_locals::~tmp_locals() {
|
||||
for (unsigned i = 0; i < m_locals.size(); i++)
|
||||
m_ctx.pop_local();
|
||||
}
|
||||
|
||||
/* =====================
|
||||
type_context
|
||||
===================== */
|
||||
|
||||
void type_context::init_core(transparency_mode m) {
|
||||
m_used_assignment = false;
|
||||
m_transparency_mode = m;
|
||||
|
|
@ -165,9 +181,15 @@ optional<expr> type_context::reduce_projection(expr const & e) {
|
|||
return some_expr(r);
|
||||
}
|
||||
|
||||
bool type_context::should_unfold_macro(expr const & e) {
|
||||
/* If m_transparency_mode is set to ALL, then we unfold all
|
||||
macros. In this way, we make sure type inference does not fail. */
|
||||
return m_transparency_mode == transparency_mode::All || m_cache->should_unfold_macro(e);
|
||||
}
|
||||
|
||||
optional<expr> type_context::expand_macro(expr const & e) {
|
||||
lean_assert(is_macro(e));
|
||||
if (m_cache->should_unfold_macro(e)) {
|
||||
if (should_unfold_macro(e)) {
|
||||
type_context_as_extension_context ext(*this);
|
||||
return macro_def(e).expand(e, ext);
|
||||
} else {
|
||||
|
|
@ -251,13 +273,231 @@ optional<expr> type_context::is_stuck(expr const & e) {
|
|||
--------------- */
|
||||
|
||||
expr type_context::infer(expr const & e) {
|
||||
// TODO(Leo)
|
||||
return e;
|
||||
flet<transparency_mode> set(m_transparency_mode, transparency_mode::All);
|
||||
return infer_core(e);
|
||||
}
|
||||
|
||||
expr type_context::infer_core(expr const & e) {
|
||||
lean_assert(!is_var(e));
|
||||
lean_assert(closed(e));
|
||||
|
||||
auto & cache = m_cache->m_infer_cache;
|
||||
auto it = cache.find(e);
|
||||
if (it != cache.end())
|
||||
return it->second;
|
||||
|
||||
reset_used_assignment reset(*this);
|
||||
|
||||
expr r;
|
||||
switch (e.kind()) {
|
||||
case expr_kind::Local:
|
||||
r = infer_local(e);
|
||||
break;
|
||||
case expr_kind::Meta:
|
||||
r = infer_metavar(e);
|
||||
break;
|
||||
case expr_kind::Var:
|
||||
lean_unreachable(); // LCOV_EXCL_LINE
|
||||
case expr_kind::Sort:
|
||||
r = mk_sort(mk_succ(sort_level(e)));
|
||||
break;
|
||||
case expr_kind::Constant:
|
||||
r = infer_constant(e);
|
||||
break;
|
||||
case expr_kind::Macro:
|
||||
r = infer_macro(e);
|
||||
break;
|
||||
case expr_kind::Lambda:
|
||||
r = infer_lambda(e);
|
||||
break;
|
||||
case expr_kind::Pi:
|
||||
r = infer_pi(e);
|
||||
break;
|
||||
case expr_kind::App:
|
||||
r = infer_app(e);
|
||||
break;
|
||||
case expr_kind::Let:
|
||||
r = infer_let(e);
|
||||
break;
|
||||
}
|
||||
|
||||
if (!m_used_assignment)
|
||||
cache.insert(mk_pair(e, r));
|
||||
return r;
|
||||
}
|
||||
|
||||
expr type_context::infer_local(expr const & e) {
|
||||
lean_assert(is_local(e));
|
||||
if (is_local_decl_ref(e)) {
|
||||
auto d = m_lctx.get_local_decl(e);
|
||||
if (!d)
|
||||
throw exception("infer type failed, unknown variable");
|
||||
lean_assert(d);
|
||||
return d->get_type();
|
||||
} else {
|
||||
/* Remark: depending on how we re-organize the parser, we may be able
|
||||
to remove this branch. */
|
||||
return mlocal_type(e);
|
||||
}
|
||||
}
|
||||
|
||||
expr type_context::infer_metavar(expr const & e) {
|
||||
if (is_metavar_decl_ref(e)) {
|
||||
auto d = m_mctx.get_metavar_decl(e);
|
||||
if (!d)
|
||||
throw exception("infer type failed, unknown metavariable");
|
||||
return d->get_type();
|
||||
} else if (m_tmp_mode && is_idx_metavar(e)) {
|
||||
/* tmp metavariables should only occur in tmp_mode */
|
||||
return mlocal_type(e);
|
||||
} else {
|
||||
lean_unreachable();
|
||||
}
|
||||
}
|
||||
|
||||
expr type_context::infer_constant(expr const & e) {
|
||||
declaration d = env().get(const_name(e));
|
||||
auto const & ps = d.get_univ_params();
|
||||
auto const & ls = const_levels(e);
|
||||
if (length(ps) != length(ls))
|
||||
throw exception("infer type failed, incorrect number of universe levels");
|
||||
return instantiate_type_univ_params(d, ls);
|
||||
}
|
||||
|
||||
expr type_context::infer_macro(expr const & e) {
|
||||
auto def = macro_def(e);
|
||||
bool infer_only = true;
|
||||
type_context_as_extension_context ext(*this);
|
||||
return def.check_type(e, ext, infer_only).first;
|
||||
}
|
||||
|
||||
expr type_context::infer_lambda(expr e) {
|
||||
buffer<expr> es, ds;
|
||||
tmp_locals ls(*this);
|
||||
while (is_lambda(e)) {
|
||||
es.push_back(e);
|
||||
ds.push_back(binding_domain(e));
|
||||
expr d = instantiate_rev(binding_domain(e), ls.size(), ls.data());
|
||||
expr l = ls.push_local(binding_name(e), d, binding_info(e));
|
||||
e = binding_body(e);
|
||||
}
|
||||
check_system("infer_type");
|
||||
expr t = infer_core(instantiate_rev(e, ls.size(), ls.data()));
|
||||
expr r = abstract_locals(t, ls.size(), ls.data());
|
||||
unsigned i = es.size();
|
||||
while (i > 0) {
|
||||
--i;
|
||||
r = mk_pi(binding_name(es[i]), ds[i], r, binding_info(es[i]));
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
optional<level> type_context::get_level_core(expr const & A) {
|
||||
expr A_type = whnf(infer_core(A));
|
||||
while (true) {
|
||||
if (is_sort(A_type)) {
|
||||
return some_level(sort_level(A_type));
|
||||
} else if (is_mvar(A_type)) {
|
||||
if (auto v = get_assignment(A_type)) {
|
||||
A_type = *v;
|
||||
} else if (!m_tmp_mode && is_metavar_decl_ref(A_type)) {
|
||||
/* We should only assign A_type IF we are not in tmp mode. */
|
||||
level r = m_mctx.mk_univ_metavar_decl();
|
||||
assign(A_type, mk_sort(r));
|
||||
return some_level(r);
|
||||
} else if (m_tmp_mode && is_idx_metavar(A_type)) {
|
||||
level r = mk_tmp_univ_mvar();
|
||||
assign(A_type, mk_sort(r));
|
||||
return some_level(r);
|
||||
} else {
|
||||
return none_level();
|
||||
}
|
||||
} else {
|
||||
return none_level();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
level type_context::get_level(expr const & A) {
|
||||
if (auto r = get_level_core(A)) {
|
||||
return *r;
|
||||
} else {
|
||||
throw exception("infer type failed, sort expected");
|
||||
}
|
||||
}
|
||||
|
||||
expr type_context::infer_pi(expr e) {
|
||||
tmp_locals ls(*this);
|
||||
buffer<level> us;
|
||||
while (is_pi(e)) {
|
||||
expr d = instantiate_rev(binding_domain(e), ls.size(), ls.data());
|
||||
us.push_back(get_level(d));
|
||||
expr l = ls.push_local(binding_name(e), d, binding_info(e));
|
||||
e = binding_body(e);
|
||||
}
|
||||
e = instantiate_rev(e, ls.size(), ls.data());
|
||||
level r = get_level(e);
|
||||
unsigned i = ls.size();
|
||||
bool imp = env().impredicative();
|
||||
while (i > 0) {
|
||||
--i;
|
||||
r = imp ? mk_imax(us[i], r) : mk_max(us[i], r);
|
||||
}
|
||||
return mk_sort(r);
|
||||
}
|
||||
|
||||
expr type_context::infer_app(expr const & e) {
|
||||
check_system("infer_type");
|
||||
buffer<expr> args;
|
||||
expr const & f = get_app_args(e, args);
|
||||
expr f_type = infer_core(f);
|
||||
unsigned j = 0;
|
||||
unsigned nargs = args.size();
|
||||
for (unsigned i = 0; i < nargs; i++) {
|
||||
if (is_pi(f_type)) {
|
||||
f_type = binding_body(f_type);
|
||||
} else {
|
||||
f_type = whnf(instantiate_rev(f_type, i-j, args.data()+j));
|
||||
if (!is_pi(f_type))
|
||||
throw exception("infer type failed, Pi expected");
|
||||
f_type = binding_body(f_type);
|
||||
j = i;
|
||||
}
|
||||
}
|
||||
return instantiate_rev(f_type, nargs-j, args.data()+j);
|
||||
}
|
||||
|
||||
expr type_context::infer_let(expr e) {
|
||||
/*
|
||||
We may also infer the type of a let-expression by using
|
||||
tmp_locals, push_let, and they closing the resulting type with
|
||||
let-expressions.
|
||||
It is unclear which option is the best / more efficient.
|
||||
The following implementation doesn't need the extra step,
|
||||
but it relies on the cache to avoid repeated work.
|
||||
*/
|
||||
buffer<expr> vs;
|
||||
while (is_let(e)) {
|
||||
expr v = instantiate_rev(let_value(e), vs.size(), vs.data());
|
||||
vs.push_back(v);
|
||||
e = let_body(e);
|
||||
}
|
||||
check_system("infer_type");
|
||||
return infer_core(instantiate_rev(e, vs.size(), vs.data()));
|
||||
}
|
||||
|
||||
expr type_context::check(expr const & e) {
|
||||
// TODO(Leo)
|
||||
return e;
|
||||
// TODO(Leo): infer doesn't really check anything
|
||||
return infer(e);
|
||||
}
|
||||
|
||||
bool type_context::is_prop(expr const & e) {
|
||||
if (env().impredicative()) {
|
||||
expr t = whnf(infer(e));
|
||||
return t == mk_Prop();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/* -------------------------------
|
||||
|
|
@ -300,6 +540,12 @@ void type_context::assign_tmp(expr const & m, expr const & v) {
|
|||
m_tmp_eassignment[to_meta_idx(m)] = v;
|
||||
}
|
||||
|
||||
level type_context::mk_tmp_univ_mvar() {
|
||||
unsigned idx = m_tmp_uassignment.size();
|
||||
m_tmp_uassignment.push_back(none_level());
|
||||
return mk_idx_metauniv(idx);
|
||||
}
|
||||
|
||||
/* -----------------------------------
|
||||
Uniform interface to tmp/regular metavariables
|
||||
----------------------------------- */
|
||||
|
|
@ -456,6 +702,54 @@ optional<declaration> type_context::is_delta(expr const & e) {
|
|||
}
|
||||
|
||||
|
||||
bool type_context::is_def_eq_core(expr const & t, expr const & s) {
|
||||
check_system("is_definitionally_equal");
|
||||
// TODO(Leo)
|
||||
return false;
|
||||
}
|
||||
|
||||
bool type_context::is_def_eq_binding(expr e1, expr e2) {
|
||||
lean_assert(e1.kind() == e2.kind());
|
||||
lean_assert(is_binding(e1));
|
||||
expr_kind k = e1.kind();
|
||||
tmp_locals subst(*this);
|
||||
do {
|
||||
optional<expr> var_e1_type;
|
||||
if (binding_domain(e1) != binding_domain(e2)) {
|
||||
var_e1_type = instantiate_rev(binding_domain(e1), subst.size(), subst.data());
|
||||
expr var_e2_type = instantiate_rev(binding_domain(e2), subst.size(), subst.data());
|
||||
if (!is_def_eq_core(var_e2_type, *var_e1_type))
|
||||
return false;
|
||||
}
|
||||
if (!closed(binding_body(e1)) || !closed(binding_body(e2))) {
|
||||
// local is used inside t or s
|
||||
if (!var_e1_type)
|
||||
var_e1_type = instantiate_rev(binding_domain(e1), subst.size(), subst.data());
|
||||
subst.push_local(binding_name(e1), *var_e1_type);
|
||||
} else {
|
||||
expr const & dont_care = mk_Prop();
|
||||
subst.push_local(binding_name(e1), dont_care);
|
||||
}
|
||||
e1 = binding_body(e1);
|
||||
e2 = binding_body(e2);
|
||||
} while (e1.kind() == k && e2.kind() == k);
|
||||
return is_def_eq_core(instantiate_rev(e1, subst.size(), subst.data()),
|
||||
instantiate_rev(e2, subst.size(), subst.data()));
|
||||
}
|
||||
|
||||
bool type_context::is_def_eq_args(expr const & e1, expr const & e2) {
|
||||
lean_assert(is_app(e1) && is_app(e2));
|
||||
buffer<expr> args1, args2;
|
||||
get_app_args(e1, args1);
|
||||
get_app_args(e2, args2);
|
||||
if (args1.size() != args2.size())
|
||||
return false;
|
||||
for (unsigned i = 0; i < args1.size(); i++) {
|
||||
if (!is_def_eq_core(args1[i], args2[i]))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
struct unification_hint_fn {
|
||||
|
|
|
|||
|
|
@ -207,6 +207,7 @@ public:
|
|||
private:
|
||||
void init_core(transparency_mode m);
|
||||
optional<expr> reduce_projection(expr const & e);
|
||||
bool should_unfold_macro(expr const & e);
|
||||
optional<expr> expand_macro(expr const & e);
|
||||
expr whnf_core(expr const & e);
|
||||
optional<declaration> is_transparent(name const & n);
|
||||
|
|
@ -217,6 +218,8 @@ private:
|
|||
void assign_tmp(level const & u, level const & l);
|
||||
void assign_tmp(expr const & m, expr const & v);
|
||||
|
||||
level mk_tmp_univ_mvar();
|
||||
|
||||
/* ------------
|
||||
Uniform interface to tmp/regular metavariables
|
||||
That is, in tmp mode they access m_tmp_eassignment and m_tmp_uassignment,
|
||||
|
|
@ -232,10 +235,58 @@ public:
|
|||
void assign(level const & u, level const & l);
|
||||
void assign(expr const & m, expr const & v);
|
||||
|
||||
private:
|
||||
/* ------------
|
||||
Type inference
|
||||
------------ */
|
||||
expr infer_core(expr const & e);
|
||||
expr infer_local(expr const & e);
|
||||
expr infer_metavar(expr const & e);
|
||||
expr infer_constant(expr const & e);
|
||||
expr infer_macro(expr const & e);
|
||||
expr infer_lambda(expr e);
|
||||
optional<level> get_level_core(expr const & A);
|
||||
level get_level(expr const & A);
|
||||
expr infer_pi(expr e);
|
||||
expr infer_app(expr const & e);
|
||||
expr infer_let(expr e);
|
||||
|
||||
private:
|
||||
level instantiate(level const & l);
|
||||
expr instantiate(expr const & l);
|
||||
bool is_def_eq(levels const & ls1, levels const & ls2);
|
||||
|
||||
optional<declaration> is_delta(expr const & e);
|
||||
|
||||
bool is_def_eq(levels const & ls1, levels const & ls2);
|
||||
bool is_def_eq_core(expr const & t, expr const & s);
|
||||
bool is_def_eq_binding(expr e1, expr e2);
|
||||
bool is_def_eq_args(expr const & e1, expr const & e2);
|
||||
|
||||
public:
|
||||
/* Helper class for creating pushing local declarations on m_lctx */
|
||||
class tmp_locals {
|
||||
type_context & m_ctx;
|
||||
buffer<expr> m_locals;
|
||||
public:
|
||||
tmp_locals(type_context & ctx):m_ctx(ctx) {}
|
||||
~tmp_locals();
|
||||
|
||||
expr push_local(name const & pp_name, expr const & type, binder_info const & bi = binder_info()) {
|
||||
expr r = m_ctx.push_local(pp_name, type, bi);
|
||||
m_locals.push_back(r);
|
||||
return r;
|
||||
}
|
||||
|
||||
expr push_let(name const & name, expr const & type, expr const & value) {
|
||||
expr r = m_ctx.push_let(name, type, value);
|
||||
m_locals.push_back(r);
|
||||
return r;
|
||||
}
|
||||
|
||||
unsigned size() const { return m_locals.size(); }
|
||||
expr const * data() const { return m_locals.data(); }
|
||||
|
||||
buffer<expr> const & as_buffer() const { return m_locals; }
|
||||
};
|
||||
};
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue