diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index 5380da01c7..0c481b51f9 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -10,6 +10,7 @@ Author: Leonardo de Moura #include "util/lbool.h" #include "util/flet.h" #include "util/sstream.h" +#include "util/scoped_map.h" #include "kernel/type_checker.h" #include "kernel/expr_maps.h" #include "kernel/instantiate.h" @@ -33,6 +34,8 @@ add_cnstr_fn mk_no_contranint_fn() { /** \brief Auxiliary functional object used to implement type checker. */ struct type_checker::imp { + typedef scoped_map cache; + /** \brief Interface type_checker <-> converter */ class converter_context : public converter::context { imp & m_imp; @@ -64,16 +67,63 @@ struct type_checker::imp { // Examples: // The type of (lambda x : A, t) is (Pi x : A, typeof(t)) // The type of (lambda {x : A}, t) is (Pi {x : A}, typeof(t)) - expr_bi_struct_map m_infer_type_cache[2]; + cache m_infer_type_cache[2]; converter_context m_conv_ctx; type_checker_context m_tc_ctx; bool m_memoize; // temp flag level_param_names m_params; + buffer m_cs; // temporary cache of constraints + bool m_cache_cs; // true if we should cache the constraints; false if we should send to m_add_cnstr_fn + + // Auxiliary object used to restore cache and filter constraints + // when a failure occurs in the type checker. + // That is, we should not keep cached results, and we should not sent constraints + // when a failure occurs. + struct scope { + imp & m_imp; + unsigned m_old_cs_size; + bool m_old_cache_cs; + bool m_keep; + scope(imp & i):m_imp(i), m_old_cs_size(m_imp.m_cs.size()), m_old_cache_cs(m_imp.m_cache_cs), m_keep(false) { + m_imp.m_infer_type_cache[0].push(); + m_imp.m_infer_type_cache[1].push(); + m_imp.m_cache_cs = true; + } + ~scope() { + if (m_keep) { + // keep results + m_imp.m_infer_type_cache[0].keep(); + m_imp.m_infer_type_cache[1].keep(); + } else { + // restore caches + m_imp.m_infer_type_cache[0].pop(); + m_imp.m_infer_type_cache[1].pop(); + m_imp.m_cs.shrink(m_old_cs_size); + } + m_imp.m_cache_cs = m_old_cache_cs; + } + void keep() { + m_keep = true; + if (!m_old_cache_cs) { + lean_assert(m_old_cs_size == 0); + // send results to m_add_cnstr_fn + try { + for (auto const & c : m_imp.m_cs) + m_imp.m_add_cnstr_fn(c); + } catch (...) { + m_imp.m_cs.clear(); + throw; + } + m_imp.m_cs.clear(); + } + } + }; imp(environment const & env, name_generator const & g, add_cnstr_fn const & h, std::unique_ptr && conv, bool memoize): m_env(env), m_gen(g), m_add_cnstr_fn(h), m_conv(std::move(conv)), m_conv_ctx(*this), m_tc_ctx(*this), - m_memoize(memoize) {} + m_memoize(memoize), m_cache_cs(false) { + } optional expand_macro(expr const & m) { lean_assert(is_macro(m)); @@ -91,17 +141,10 @@ struct type_checker::imp { /** \brief Add given constraint using m_add_cnstr_fn. */ void add_cnstr(constraint const & c) { - m_add_cnstr_fn(c); - } - - /** \brief Return true iff \c t and \c s are definitionally equal */ - bool is_def_eq(expr const & t, expr const & s, delayed_justification & jst) { - return m_conv->is_def_eq(t, s, m_conv_ctx, jst); - } - - /** \brief Return true iff \c e is a proposition */ - bool is_prop(expr const & e) { - return whnf(infer_type(e)) == Bool; + if (m_cache_cs) + m_cs.push_back(c); + else + m_add_cnstr_fn(c); } /** @@ -157,7 +200,7 @@ struct type_checker::imp { \remark \c s is used to extract position (line number information) when an error message is produced */ - expr ensure_sort(expr e, expr const & s) { + expr ensure_sort_core(expr e, expr const & s) { if (is_sort(e)) return e; e = whnf(e); @@ -181,7 +224,7 @@ struct type_checker::imp { } /** \brief Similar to \c ensure_sort, but makes sure \c e "is" a Pi. */ - expr ensure_pi(expr e, expr const & s) { + expr ensure_pi_core(expr e, expr const & s) { if (is_pi(e)) return e; e = whnf(e); @@ -339,16 +382,16 @@ struct type_checker::imp { case expr_kind::Lambda: { if (!infer_only) { expr t = infer_type_core(binding_domain(e), infer_only); - ensure_sort(t, binding_domain(e)); + ensure_sort_core(t, binding_domain(e)); } auto b = open_binding_body(e); r = mk_pi(binding_name(e), binding_domain(e), abstract_local(infer_type_core(b.first, infer_only), b.second), binding_info(e)); break; } case expr_kind::Pi: { - expr t1 = ensure_sort(infer_type_core(binding_domain(e), infer_only), binding_domain(e)); + expr t1 = ensure_sort_core(infer_type_core(binding_domain(e), infer_only), binding_domain(e)); auto b = open_binding_body(e); - expr t2 = ensure_sort(infer_type_core(b.first, infer_only), binding_body(e)); + expr t2 = ensure_sort_core(infer_type_core(b.first, infer_only), binding_body(e)); if (m_env.impredicative()) r = mk_sort(mk_imax(sort_level(t1), sort_level(t2))); else @@ -356,7 +399,7 @@ struct type_checker::imp { break; } case expr_kind::App: { - expr f_type = ensure_pi(infer_type_core(app_fn(e), infer_only), app_fn(e)); + expr f_type = ensure_pi_core(infer_type_core(app_fn(e), infer_only), app_fn(e)); if (!infer_only) { expr a_type = infer_type_core(app_arg(e), infer_only); app_delayed_jst jst(m_env, e, f_type, a_type); @@ -373,7 +416,7 @@ struct type_checker::imp { } case expr_kind::Let: if (!infer_only) { - ensure_sort(infer_type_core(let_type(e), infer_only), let_type(e)); + ensure_sort_core(infer_type_core(let_type(e), infer_only), let_type(e)); expr val_type = infer_type_core(let_value(e), infer_only); simple_delayed_justification jst([=]() { return mk_let_mismatch_jst(e, val_type); }); if (!is_def_eq(val_type, let_type(e), jst)) { @@ -394,16 +437,55 @@ struct type_checker::imp { return r; } - expr infer_type(expr const & e) { return infer_type_core(e, true); } - expr check(expr const & e, level_param_names const & ps) { - flet updt(m_params, ps); - return infer_type_core(e, false); + expr infer_type(expr const & e) { + scope mk_scope(*this); + expr r = infer_type_core(e, true); + mk_scope.keep(); + return r; + } + expr check(expr const & e, level_param_names const & ps) { + scope mk_scope(*this); + flet updt(m_params, ps); + expr r = infer_type_core(e, false); + mk_scope.keep(); + return r; + } + expr ensure_sort(expr const & e, expr const & s) { + scope mk_scope(*this); + expr r = ensure_sort_core(e, s); + mk_scope.keep(); + return r; + } + expr ensure_pi(expr const & e, expr const & s) { + scope mk_scope(*this); + expr r = ensure_pi_core(e, s); + mk_scope.keep(); + return r; + } + /** \brief Return true iff \c t and \c s are definitionally equal */ + bool is_def_eq(expr const & t, expr const & s, delayed_justification & jst) { + scope mk_scope(*this); + bool r = m_conv->is_def_eq(t, s, m_conv_ctx, jst); + if (r) mk_scope.keep(); + return r; + } + bool is_def_eq(expr const & t, expr const & s) { + scope mk_scope(*this); + bool r = m_conv->is_def_eq(t, s, m_conv_ctx); + if (r) mk_scope.keep(); + return r; } - bool is_def_eq(expr const & t, expr const & s) { return m_conv->is_def_eq(t, s, m_conv_ctx); } bool is_def_eq(expr const & t, expr const & s, justification const & j) { as_delayed_justification djst(j); return is_def_eq(t, s, djst); } + /** \brief Return true iff \c e is a proposition */ + bool is_prop(expr const & e) { + scope mk_scope(*this); + bool r = whnf(infer_type(e)) == Bool; + if (r) mk_scope.keep(); + return r; + } expr whnf(expr const & t) { return m_conv->whnf(t, m_conv_ctx); } }; diff --git a/tests/lua/expr9.lua b/tests/lua/expr9.lua index 295bec3b22..84216e80c1 100644 --- a/tests/lua/expr9.lua +++ b/tests/lua/expr9.lua @@ -5,11 +5,16 @@ print(env:normalize(Fun(a, m))) print(env:normalize(Fun(a, m(a)))) local m2 = mk_metavar("m2", mk_arrow(Bool, Bool, Bool)) print(env:normalize(Fun(a, (m2(a))(a)))) +print("step1") env:type_check(m) +print("step2") env:type_check(Fun(a, m(a))) +print("step3") env:type_check(Fun(a, (m2(a))(a))) local m3 = mk_metavar("m3", mk_metavar("m4", mk_sort(mk_meta_univ("l")))) +print("step4") env:type_check(m3) +print("step5") -- The following call fails, because the type checker will try to -- create a constraint, but constraint generation is not supported by -- the type checker used to implement the method type_check @@ -17,7 +22,4 @@ assert(not pcall(function() env:type_check(m3(a)) end )) - - - - +print("before end") diff --git a/tests/lua/tc1.lua b/tests/lua/tc1.lua index da0801458d..f1f4c4c4f5 100644 --- a/tests/lua/tc1.lua +++ b/tests/lua/tc1.lua @@ -7,14 +7,19 @@ local t = Fun(a, Bool, a) local b = Const("b") print(t(b)) assert(tc:whnf(t(b)) == b) -local cs = {} -local tc2 = type_checker(env, g, function (c) print(c); cs[#cs+1] = c end) assert(tc:check(Bool) == mk_sort(mk_level_one())) print(tc:infer(t)) local m = mk_metavar("m1", mk_metavar("m2", mk_sort(mk_meta_univ("u")))) print(tc:infer(m)) + +local cs = {} +local tc2 = type_checker(env, g, function (c) print(c); cs[#cs+1] = c end) local t2 = Fun(a, Bool, m(a)) +print("---------") +print("t2: ") print(t2) +print("check(t): ") print(tc2:check(t)) +print("check(t2): ") print(tc2:check(t2)) assert(#cs == 2)