From b69314a308ccdc564deaa84c62081e428e0f33dc Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 10 Mar 2016 16:25:54 -0800 Subject: [PATCH] feat(library/type_context): infer --- src/library/type_context.cpp | 304 ++++++++++++++++++++++++++++++++++- src/library/type_context.h | 53 +++++- 2 files changed, 351 insertions(+), 6 deletions(-) diff --git a/src/library/type_context.cpp b/src/library/type_context.cpp index 80bcaf30f4..81b4dba4ad 100644 --- a/src/library/type_context.cpp +++ b/src/library/type_context.cpp @@ -38,6 +38,10 @@ struct type_context_as_extension_context : public extension_context { } }; +/* ===================== + type_context_cache + ===================== */ + type_context_cache::type_context_cache(environment const & env, options const & opts): m_env(env), m_options(opts), @@ -86,6 +90,18 @@ bool type_context_cache::should_unfold_macro(expr const &) { return true; } +/* ===================== + type_context::tmp_locals + ===================== */ +type_context::tmp_locals::~tmp_locals() { + for (unsigned i = 0; i < m_locals.size(); i++) + m_ctx.pop_local(); +} + +/* ===================== + type_context + ===================== */ + void type_context::init_core(transparency_mode m) { m_used_assignment = false; m_transparency_mode = m; @@ -165,9 +181,15 @@ optional type_context::reduce_projection(expr const & e) { return some_expr(r); } +bool type_context::should_unfold_macro(expr const & e) { + /* If m_transparency_mode is set to ALL, then we unfold all + macros. In this way, we make sure type inference does not fail. */ + return m_transparency_mode == transparency_mode::All || m_cache->should_unfold_macro(e); +} + optional type_context::expand_macro(expr const & e) { lean_assert(is_macro(e)); - if (m_cache->should_unfold_macro(e)) { + if (should_unfold_macro(e)) { type_context_as_extension_context ext(*this); return macro_def(e).expand(e, ext); } else { @@ -251,13 +273,231 @@ optional type_context::is_stuck(expr const & e) { --------------- */ expr type_context::infer(expr const & e) { - // TODO(Leo) - return e; + flet set(m_transparency_mode, transparency_mode::All); + return infer_core(e); +} + +expr type_context::infer_core(expr const & e) { + lean_assert(!is_var(e)); + lean_assert(closed(e)); + + auto & cache = m_cache->m_infer_cache; + auto it = cache.find(e); + if (it != cache.end()) + return it->second; + + reset_used_assignment reset(*this); + + expr r; + switch (e.kind()) { + case expr_kind::Local: + r = infer_local(e); + break; + case expr_kind::Meta: + r = infer_metavar(e); + break; + case expr_kind::Var: + lean_unreachable(); // LCOV_EXCL_LINE + case expr_kind::Sort: + r = mk_sort(mk_succ(sort_level(e))); + break; + case expr_kind::Constant: + r = infer_constant(e); + break; + case expr_kind::Macro: + r = infer_macro(e); + break; + case expr_kind::Lambda: + r = infer_lambda(e); + break; + case expr_kind::Pi: + r = infer_pi(e); + break; + case expr_kind::App: + r = infer_app(e); + break; + case expr_kind::Let: + r = infer_let(e); + break; + } + + if (!m_used_assignment) + cache.insert(mk_pair(e, r)); + return r; +} + +expr type_context::infer_local(expr const & e) { + lean_assert(is_local(e)); + if (is_local_decl_ref(e)) { + auto d = m_lctx.get_local_decl(e); + if (!d) + throw exception("infer type failed, unknown variable"); + lean_assert(d); + return d->get_type(); + } else { + /* Remark: depending on how we re-organize the parser, we may be able + to remove this branch. */ + return mlocal_type(e); + } +} + +expr type_context::infer_metavar(expr const & e) { + if (is_metavar_decl_ref(e)) { + auto d = m_mctx.get_metavar_decl(e); + if (!d) + throw exception("infer type failed, unknown metavariable"); + return d->get_type(); + } else if (m_tmp_mode && is_idx_metavar(e)) { + /* tmp metavariables should only occur in tmp_mode */ + return mlocal_type(e); + } else { + lean_unreachable(); + } +} + +expr type_context::infer_constant(expr const & e) { + declaration d = env().get(const_name(e)); + auto const & ps = d.get_univ_params(); + auto const & ls = const_levels(e); + if (length(ps) != length(ls)) + throw exception("infer type failed, incorrect number of universe levels"); + return instantiate_type_univ_params(d, ls); +} + +expr type_context::infer_macro(expr const & e) { + auto def = macro_def(e); + bool infer_only = true; + type_context_as_extension_context ext(*this); + return def.check_type(e, ext, infer_only).first; +} + +expr type_context::infer_lambda(expr e) { + buffer es, ds; + tmp_locals ls(*this); + while (is_lambda(e)) { + es.push_back(e); + ds.push_back(binding_domain(e)); + expr d = instantiate_rev(binding_domain(e), ls.size(), ls.data()); + expr l = ls.push_local(binding_name(e), d, binding_info(e)); + e = binding_body(e); + } + check_system("infer_type"); + expr t = infer_core(instantiate_rev(e, ls.size(), ls.data())); + expr r = abstract_locals(t, ls.size(), ls.data()); + unsigned i = es.size(); + while (i > 0) { + --i; + r = mk_pi(binding_name(es[i]), ds[i], r, binding_info(es[i])); + } + return r; +} + +optional type_context::get_level_core(expr const & A) { + expr A_type = whnf(infer_core(A)); + while (true) { + if (is_sort(A_type)) { + return some_level(sort_level(A_type)); + } else if (is_mvar(A_type)) { + if (auto v = get_assignment(A_type)) { + A_type = *v; + } else if (!m_tmp_mode && is_metavar_decl_ref(A_type)) { + /* We should only assign A_type IF we are not in tmp mode. */ + level r = m_mctx.mk_univ_metavar_decl(); + assign(A_type, mk_sort(r)); + return some_level(r); + } else if (m_tmp_mode && is_idx_metavar(A_type)) { + level r = mk_tmp_univ_mvar(); + assign(A_type, mk_sort(r)); + return some_level(r); + } else { + return none_level(); + } + } else { + return none_level(); + } + } +} + +level type_context::get_level(expr const & A) { + if (auto r = get_level_core(A)) { + return *r; + } else { + throw exception("infer type failed, sort expected"); + } +} + +expr type_context::infer_pi(expr e) { + tmp_locals ls(*this); + buffer us; + while (is_pi(e)) { + expr d = instantiate_rev(binding_domain(e), ls.size(), ls.data()); + us.push_back(get_level(d)); + expr l = ls.push_local(binding_name(e), d, binding_info(e)); + e = binding_body(e); + } + e = instantiate_rev(e, ls.size(), ls.data()); + level r = get_level(e); + unsigned i = ls.size(); + bool imp = env().impredicative(); + while (i > 0) { + --i; + r = imp ? mk_imax(us[i], r) : mk_max(us[i], r); + } + return mk_sort(r); +} + +expr type_context::infer_app(expr const & e) { + check_system("infer_type"); + buffer args; + expr const & f = get_app_args(e, args); + expr f_type = infer_core(f); + unsigned j = 0; + unsigned nargs = args.size(); + for (unsigned i = 0; i < nargs; i++) { + if (is_pi(f_type)) { + f_type = binding_body(f_type); + } else { + f_type = whnf(instantiate_rev(f_type, i-j, args.data()+j)); + if (!is_pi(f_type)) + throw exception("infer type failed, Pi expected"); + f_type = binding_body(f_type); + j = i; + } + } + return instantiate_rev(f_type, nargs-j, args.data()+j); +} + +expr type_context::infer_let(expr e) { + /* + We may also infer the type of a let-expression by using + tmp_locals, push_let, and they closing the resulting type with + let-expressions. + It is unclear which option is the best / more efficient. + The following implementation doesn't need the extra step, + but it relies on the cache to avoid repeated work. + */ + buffer vs; + while (is_let(e)) { + expr v = instantiate_rev(let_value(e), vs.size(), vs.data()); + vs.push_back(v); + e = let_body(e); + } + check_system("infer_type"); + return infer_core(instantiate_rev(e, vs.size(), vs.data())); } expr type_context::check(expr const & e) { - // TODO(Leo) - return e; + // TODO(Leo): infer doesn't really check anything + return infer(e); +} + +bool type_context::is_prop(expr const & e) { + if (env().impredicative()) { + expr t = whnf(infer(e)); + return t == mk_Prop(); + } else { + return false; + } } /* ------------------------------- @@ -300,6 +540,12 @@ void type_context::assign_tmp(expr const & m, expr const & v) { m_tmp_eassignment[to_meta_idx(m)] = v; } +level type_context::mk_tmp_univ_mvar() { + unsigned idx = m_tmp_uassignment.size(); + m_tmp_uassignment.push_back(none_level()); + return mk_idx_metauniv(idx); +} + /* ----------------------------------- Uniform interface to tmp/regular metavariables ----------------------------------- */ @@ -456,6 +702,54 @@ optional type_context::is_delta(expr const & e) { } +bool type_context::is_def_eq_core(expr const & t, expr const & s) { + check_system("is_definitionally_equal"); + // TODO(Leo) + return false; +} + +bool type_context::is_def_eq_binding(expr e1, expr e2) { + lean_assert(e1.kind() == e2.kind()); + lean_assert(is_binding(e1)); + expr_kind k = e1.kind(); + tmp_locals subst(*this); + do { + optional var_e1_type; + if (binding_domain(e1) != binding_domain(e2)) { + var_e1_type = instantiate_rev(binding_domain(e1), subst.size(), subst.data()); + expr var_e2_type = instantiate_rev(binding_domain(e2), subst.size(), subst.data()); + if (!is_def_eq_core(var_e2_type, *var_e1_type)) + return false; + } + if (!closed(binding_body(e1)) || !closed(binding_body(e2))) { + // local is used inside t or s + if (!var_e1_type) + var_e1_type = instantiate_rev(binding_domain(e1), subst.size(), subst.data()); + subst.push_local(binding_name(e1), *var_e1_type); + } else { + expr const & dont_care = mk_Prop(); + subst.push_local(binding_name(e1), dont_care); + } + e1 = binding_body(e1); + e2 = binding_body(e2); + } while (e1.kind() == k && e2.kind() == k); + return is_def_eq_core(instantiate_rev(e1, subst.size(), subst.data()), + instantiate_rev(e2, subst.size(), subst.data())); +} + +bool type_context::is_def_eq_args(expr const & e1, expr const & e2) { + lean_assert(is_app(e1) && is_app(e2)); + buffer args1, args2; + get_app_args(e1, args1); + get_app_args(e2, args2); + if (args1.size() != args2.size()) + return false; + for (unsigned i = 0; i < args1.size(); i++) { + if (!is_def_eq_core(args1[i], args2[i])) + return false; + } + return true; +} /* struct unification_hint_fn { diff --git a/src/library/type_context.h b/src/library/type_context.h index 18fa264c08..aa9016ef88 100644 --- a/src/library/type_context.h +++ b/src/library/type_context.h @@ -207,6 +207,7 @@ public: private: void init_core(transparency_mode m); optional reduce_projection(expr const & e); + bool should_unfold_macro(expr const & e); optional expand_macro(expr const & e); expr whnf_core(expr const & e); optional is_transparent(name const & n); @@ -217,6 +218,8 @@ private: void assign_tmp(level const & u, level const & l); void assign_tmp(expr const & m, expr const & v); + level mk_tmp_univ_mvar(); + /* ------------ Uniform interface to tmp/regular metavariables That is, in tmp mode they access m_tmp_eassignment and m_tmp_uassignment, @@ -232,10 +235,58 @@ public: void assign(level const & u, level const & l); void assign(expr const & m, expr const & v); +private: + /* ------------ + Type inference + ------------ */ + expr infer_core(expr const & e); + expr infer_local(expr const & e); + expr infer_metavar(expr const & e); + expr infer_constant(expr const & e); + expr infer_macro(expr const & e); + expr infer_lambda(expr e); + optional get_level_core(expr const & A); + level get_level(expr const & A); + expr infer_pi(expr e); + expr infer_app(expr const & e); + expr infer_let(expr e); + private: level instantiate(level const & l); expr instantiate(expr const & l); - bool is_def_eq(levels const & ls1, levels const & ls2); + optional is_delta(expr const & e); + + bool is_def_eq(levels const & ls1, levels const & ls2); + bool is_def_eq_core(expr const & t, expr const & s); + bool is_def_eq_binding(expr e1, expr e2); + bool is_def_eq_args(expr const & e1, expr const & e2); + +public: + /* Helper class for creating pushing local declarations on m_lctx */ + class tmp_locals { + type_context & m_ctx; + buffer m_locals; + public: + tmp_locals(type_context & ctx):m_ctx(ctx) {} + ~tmp_locals(); + + expr push_local(name const & pp_name, expr const & type, binder_info const & bi = binder_info()) { + expr r = m_ctx.push_local(pp_name, type, bi); + m_locals.push_back(r); + return r; + } + + expr push_let(name const & name, expr const & type, expr const & value) { + expr r = m_ctx.push_let(name, type, value); + m_locals.push_back(r); + return r; + } + + unsigned size() const { return m_locals.size(); } + expr const * data() const { return m_locals.data(); } + + buffer const & as_buffer() const { return m_locals; } + }; }; }