diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index 0c481b51f9..74368d2e7b 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -487,6 +487,20 @@ struct type_checker::imp { return r; } expr whnf(expr const & t) { return m_conv->whnf(t, m_conv_ctx); } + void push() { + lean_assert(!m_cache_cs); + m_infer_type_cache[0].push(); + m_infer_type_cache[1].push(); + } + void pop() { + lean_assert(!m_cache_cs); + m_infer_type_cache[0].pop(); + m_infer_type_cache[1].pop(); + } + unsigned num_scopes() const { + lean_assert(m_infer_type_cache[0].num_scopes() == m_infer_type_cache[1].num_scopes()); + return m_infer_type_cache[0].num_scopes(); + } }; static add_cnstr_fn g_no_constraint_fn = mk_no_contranint_fn(); @@ -512,6 +526,9 @@ bool type_checker::is_prop(expr const & t) { return m_ptr->is_prop(t); } expr type_checker::whnf(expr const & t) { return m_ptr->whnf(t); } expr type_checker::ensure_pi(expr const & t, expr const & s) { return m_ptr->ensure_pi(t, s); } expr type_checker::ensure_sort(expr const & t, expr const & s) { return m_ptr->ensure_sort(t, s); } +void type_checker::push() { m_ptr->push(); } +void type_checker::pop() { m_ptr->pop(); } +unsigned type_checker::num_scopes() const { return m_ptr->num_scopes(); } static void check_no_metavar(environment const & env, expr const & e) { if (has_metavar(e)) diff --git a/src/kernel/type_checker.h b/src/kernel/type_checker.h index 54f1f0d9b2..fd8f98d158 100644 --- a/src/kernel/type_checker.h +++ b/src/kernel/type_checker.h @@ -100,6 +100,13 @@ public: /** \brief Mare sure type of \c e is a sort, and return it. Throw an exception otherwise. */ expr ensure_type(expr const & e) { return ensure_sort(infer(e), e); } + /** \brief Create a backtracking point for cache and generated constraints. */ + void push(); + /** \brief Restore backtracking point. */ + void pop(); + /** \brief Return the number of backtracking points. */ + unsigned num_scopes() const; + void swap(type_checker & tc) { std::swap(m_ptr, tc.m_ptr); } }; diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index 7f7f413f19..fb73e691c4 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -1830,7 +1830,7 @@ static void get_type_checker_args(lua_State * L, int idx, optional & extra_opaque = get_name_set_named_param(L, idx, "extra_opaque", name_set()); } -int mk_type_checker(lua_State * L) { +static int mk_type_checker(lua_State * L) { int nargs = lua_gettop(L); if (nargs == 1) { return push_type_checker_ref(L, std::make_shared(to_environment(L, 1))); @@ -1857,29 +1857,37 @@ int mk_type_checker(lua_State * L) { } } } -int type_checker_whnf(lua_State * L) { return push_expr(L, to_type_checker_ref(L, 1)->whnf(to_expr(L, 2))); } -int type_checker_ensure_pi(lua_State * L) { +static int type_checker_whnf(lua_State * L) { return push_expr(L, to_type_checker_ref(L, 1)->whnf(to_expr(L, 2))); } +static int type_checker_ensure_pi(lua_State * L) { if (lua_gettop(L) == 2) return push_expr(L, to_type_checker_ref(L, 1)->ensure_pi(to_expr(L, 2))); else return push_expr(L, to_type_checker_ref(L, 1)->ensure_pi(to_expr(L, 2), to_expr(L, 3))); } -int type_checker_ensure_sort(lua_State * L) { +static int type_checker_ensure_sort(lua_State * L) { if (lua_gettop(L) == 2) return push_expr(L, to_type_checker_ref(L, 1)->ensure_sort(to_expr(L, 2))); else return push_expr(L, to_type_checker_ref(L, 1)->ensure_sort(to_expr(L, 2), to_expr(L, 3))); } -int type_checker_check(lua_State * L) { +static int type_checker_check(lua_State * L) { int nargs = lua_gettop(L); if (nargs <= 2) return push_expr(L, to_type_checker_ref(L, 1)->check(to_expr(L, 2), level_param_names())); else return push_expr(L, to_type_checker_ref(L, 1)->check(to_expr(L, 2), to_level_param_names(L, 3))); } -int type_checker_infer(lua_State * L) { return push_expr(L, to_type_checker_ref(L, 1)->infer(to_expr(L, 2))); } -int type_checker_is_def_eq(lua_State * L) { return push_boolean(L, to_type_checker_ref(L, 1)->is_def_eq(to_expr(L, 2), to_expr(L, 3))); } -int type_checker_is_prop(lua_State * L) { return push_boolean(L, to_type_checker_ref(L, 1)->is_prop(to_expr(L, 2))); } +static int type_checker_infer(lua_State * L) { return push_expr(L, to_type_checker_ref(L, 1)->infer(to_expr(L, 2))); } +static int type_checker_is_def_eq(lua_State * L) { return push_boolean(L, to_type_checker_ref(L, 1)->is_def_eq(to_expr(L, 2), to_expr(L, 3))); } +static int type_checker_is_prop(lua_State * L) { return push_boolean(L, to_type_checker_ref(L, 1)->is_prop(to_expr(L, 2))); } +static int type_checker_push(lua_State * L) { to_type_checker_ref(L, 1)->push(); return 0; } +static int type_checker_pop(lua_State * L) { + if (to_type_checker_ref(L, 1)->num_scopes() == 0) + throw exception("invalid pop method, type_checker does not have backtracking points"); + to_type_checker_ref(L, 1)->pop(); + return 0; +} +static int type_checker_num_scopes(lua_State * L) { return push_integer(L, to_type_checker_ref(L, 1)->num_scopes()); } static const struct luaL_Reg type_checker_ref_m[] = { {"__gc", type_checker_ref_gc}, @@ -1890,6 +1898,9 @@ static const struct luaL_Reg type_checker_ref_m[] = { {"infer", safe_function}, {"is_def_eq", safe_function}, {"is_prop", safe_function}, + {"push", safe_function}, + {"pop", safe_function}, + {"num_scopes", safe_function}, {0, 0} }; diff --git a/tests/lua/tc8.lua b/tests/lua/tc8.lua new file mode 100644 index 0000000000..f5f9fe76de --- /dev/null +++ b/tests/lua/tc8.lua @@ -0,0 +1,22 @@ +local env = environment() +local N = Const("N") +env = add_decl(env, mk_var_decl("N", Type)) +env = add_decl(env, mk_var_decl("f", mk_arrow(N, N))) +env = add_decl(env, mk_var_decl("a", N)) +local f = Const("f") +local a = Const("a") +local m1 = mk_metavar("m1", mk_metavar("m2", mk_sort(mk_meta_univ("l")))) +local cs = {} +local ngen = name_generator("tst") +local tc = type_checker(env, ngen, function (c) print(c); cs[#cs+1] = c end) +assert(tc:num_scopes() == 0) +tc:push() +assert(tc:num_scopes() == 1) +print(tc:check(f(m1))) +assert(#cs == 1) +print(tc:check(f(f(m1)))) +assert(#cs == 1) -- New constraint is not generated +tc:pop() -- forget that we checked f(m1) +print(tc:check(f(m1))) +assert(#cs == 2) -- constraint is generated again +check_error(function() tc:pop() end)