diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index cab34b53fd..3bbb8022ee 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -24,6 +24,7 @@ class type_checker::imp { environment const & m_env; cache m_cache; normalizer m_normalizer; + context m_ctx; metavar_env * m_menv; unsigned m_menv_timestamp; unification_problems * m_up; @@ -202,6 +203,13 @@ class type_checker::imp { return m_normalizer.is_convertible(t1, t2, ctx, m_menv, m_up); } + void set_ctx(context const & ctx) { + if (!is_eqp(m_ctx, ctx)) { + clear(); + m_ctx = ctx; + } + } + void set_menv(metavar_env * menv) { if (m_menv == menv) { // Check whether m_menv has been updated since the last time the normalizer has been invoked @@ -227,18 +235,21 @@ public: } level infer_universe(expr const & t, context const & ctx, metavar_env * menv, unification_problems * up) { + set_ctx(ctx); set_menv(menv); flet set(m_up, up); return infer_universe_core(t, ctx); } expr infer_type(expr const & e, context const & ctx, metavar_env * menv, unification_problems * up) { + set_ctx(ctx); set_menv(menv); flet set(m_up, up); return infer_type_core(e, ctx); } bool is_convertible(expr const & t1, expr const & t2, context const & ctx, metavar_env * menv, unification_problems * up) { + set_ctx(ctx); set_menv(menv); flet set(m_up, up); return is_convertible_core(t1, t2, ctx); @@ -252,6 +263,7 @@ public: void clear() { m_cache.clear(); m_normalizer.clear(); + m_ctx = context(); m_menv = nullptr; m_menv_timestamp = 0; } diff --git a/src/tests/kernel/type_checker.cpp b/src/tests/kernel/type_checker.cpp index 76ce595119..d36ce2042d 100644 --- a/src/tests/kernel/type_checker.cpp +++ b/src/tests/kernel/type_checker.cpp @@ -267,6 +267,34 @@ static void tst14() { } } +static void tst15() { + environment env; + import_all(env); + context ctx1, ctx2; + expr A = Const("A"); + expr vec1 = Const("vec1"); + expr vec2 = Const("vec2"); + env.add_var("vec1", Int >> (Type() >> Type())); + env.add_var("vec2", Real >> (Type() >> Type())); + ctx1 = extend(ctx1, "x", Int, iVal(1)); + ctx1 = extend(ctx1, "f", Pi({A, Int}, vec1(A, Int))); + ctx2 = extend(ctx2, "x", Real, rVal(2)); + ctx2 = extend(ctx2, "f", Pi({A, Real}, vec2(A, Real))); + expr F = Var(0)(Var(1)); + expr F_copy = F; + type_checker checker(env); + std::cout << checker.infer_type(F, ctx1) << "\n"; + lean_assert_eq(checker.infer_type(F, ctx1), vec1(Var(1), Int)); + lean_assert_eq(checker.infer_type(F, ctx2), vec2(Var(1), Real)); + lean_assert(is_eqp(checker.infer_type(F, ctx2), checker.infer_type(F, ctx2))); + lean_assert(is_eqp(checker.infer_type(F, ctx1), checker.infer_type(F, ctx1))); + expr r = checker.infer_type(F, ctx1); + checker.clear(); + lean_assert(!is_eqp(r, checker.infer_type(F, ctx1))); + r = checker.infer_type(F, ctx1); + lean_assert(is_eqp(r, checker.infer_type(F, ctx1))); +} + int main() { tst1(); tst2(); @@ -282,5 +310,6 @@ int main() { tst12(); tst13(); tst14(); + tst15(); return has_violations() ? 1 : 0; }