diff --git a/src/kernel/normalize.cpp b/src/kernel/normalize.cpp index 4f23dd8ec4..0c8b037923 100644 --- a/src/kernel/normalize.cpp +++ b/src/kernel/normalize.cpp @@ -9,6 +9,7 @@ Author: Leonardo de Moura #include "expr.h" #include "context.h" #include "environment.h" +#include "scoped_map.h" #include "builtin.h" #include "free_vars.h" #include "list.h" @@ -28,6 +29,7 @@ class svalue { expr m_expr; value_stack m_ctx; public: + svalue() {} explicit svalue(expr const & e): m_kind(svalue_kind::Expr), m_expr(e) {} explicit svalue(unsigned k): m_kind(svalue_kind::BoundedVar), m_bvar(k) {} svalue(expr const & e, value_stack const & c):m_kind(svalue_kind::Closure), m_expr(e), m_ctx(c) { lean_assert(is_lambda(e)); } @@ -52,8 +54,11 @@ value_stack extend(value_stack const & s, svalue const & v) { return cons(v, s); /** \brief Expression normalizer. */ class normalize_fn { + typedef scoped_map cache; + environment const & m_env; context const & m_ctx; + cache m_cache; svalue lookup(value_stack const & s, unsigned i, unsigned k) { unsigned j = i; @@ -127,20 +132,33 @@ class normalize_fn { /** \brief Normalize the expression \c a in a context composed of stack \c s and \c k binders. */ svalue normalize(expr const & a, value_stack const & s, unsigned k) { lean_trace("normalize", tout << "Normalize, k: " << k << "\n" << a << "\n";); + + bool shared = false; + if (is_shared(a)) { + shared = true; + auto it = m_cache.find(a); + if (it != m_cache.end()) + return it->second; + } + + svalue r; switch (a.kind()) { case expr_kind::Var: - return lookup(s, var_idx(a), k); + r = lookup(s, var_idx(a), k); + break; case expr_kind::Constant: { named_object const & obj = m_env.get_object(const_name(a)); if (obj.is_definition() && !obj.is_opaque()) { - return normalize(obj.get_value(), value_stack(), 0); + r = normalize(obj.get_value(), value_stack(), 0); } else { - return svalue(a); + r = svalue(a); } + break; } case expr_kind::Type: case expr_kind::Value: - return svalue(a); + r = svalue(a); + break; case expr_kind::App: { svalue f = normalize(arg(a, 0), s, k); unsigned i = 1; @@ -150,10 +168,15 @@ class normalize_fn { // beta reduction expr const & fv = to_expr(f); lean_trace("normalize", tout << "beta reduction...\n" << fv << "\n";); - value_stack new_s = extend(stack_of(f), normalize(arg(a, i), s, k)); - f = normalize(abst_body(fv), new_s, k); - if (i == n - 1) - return f; + { + cache::mk_scope sc(m_cache); + value_stack new_s = extend(stack_of(f), normalize(arg(a, i), s, k)); + f = normalize(abst_body(fv), new_s, k); + } + if (i == n - 1) { + r = f; + break; + } i++; } else { buffer new_args; @@ -162,37 +185,55 @@ class normalize_fn { for (; i < n; i++) new_args.push_back(reify(normalize(arg(a, i), s, k), k)); if (is_value(new_f)) { - expr r; - if (to_value(new_f).normalize(new_args.size(), new_args.data(), r)) - return svalue(r); + expr m; + if (to_value(new_f).normalize(new_args.size(), new_args.data(), m)) { + r = svalue(m); + break; + } } - return svalue(mk_app(new_args.size(), new_args.data())); + r = svalue(mk_app(new_args.size(), new_args.data())); + break; } } + break; } case expr_kind::Eq: { - expr new_l = reify(normalize(eq_lhs(a), s, k), k); - expr new_r = reify(normalize(eq_rhs(a), s, k), k); - if (new_l == new_r) { - return svalue(mk_bool_value(true)); - } else if (is_value(new_l) && is_value(new_r)) { - return svalue(mk_bool_value(false)); + expr new_lhs = reify(normalize(eq_lhs(a), s, k), k); + expr new_rhs = reify(normalize(eq_rhs(a), s, k), k); + if (new_lhs == new_rhs) { + r = svalue(mk_bool_value(true)); + } else if (is_value(new_lhs) && is_value(new_rhs)) { + r = svalue(mk_bool_value(false)); } else { - return svalue(mk_eq(new_l, new_r)); + r = svalue(mk_eq(new_lhs, new_rhs)); } + break; } case expr_kind::Lambda: - return svalue(a, s); + r = svalue(a, s); + break; case expr_kind::Pi: { expr new_t = reify(normalize(abst_domain(a), s, k), k); - expr new_b = reify(normalize(abst_body(a), extend(s, svalue(k)), k+1), k+1); - return svalue(mk_pi(abst_name(a), new_t, new_b)); + expr new_b; + { + cache::mk_scope sc(m_cache); + new_b = reify(normalize(abst_body(a), extend(s, svalue(k)), k+1), k+1); + } + r = svalue(mk_pi(abst_name(a), new_t, new_b)); + break; } - case expr_kind::Let: - return normalize(let_body(a), extend(s, normalize(let_value(a), s, k)), k+1); + case expr_kind::Let: { + svalue v = normalize(let_value(a), s, k); + { + cache::mk_scope sc(m_cache); + r = normalize(let_body(a), extend(s, v), k+1); + } + break; + }} + if (shared) { + m_cache.insert(a, r); } - lean_unreachable(); - return svalue(a); + return r; } public: diff --git a/src/kernel/type_check.cpp b/src/kernel/type_check.cpp index 8e2f48c324..df8a475728 100644 --- a/src/kernel/type_check.cpp +++ b/src/kernel/type_check.cpp @@ -125,7 +125,7 @@ struct infer_type_fn { lean_trace("type_check", tout << "infer type\n" << e << "\n" << ctx << "\n";); bool shared = false; - if (true && is_shared(e)) { + if (is_shared(e)) { shared = true; auto it = m_cache.find(e); if (it != m_cache.end()) diff --git a/src/tests/kernel/type_check.cpp b/src/tests/kernel/type_check.cpp index eb86a0994d..c44a2698a2 100644 --- a/src/tests/kernel/type_check.cpp +++ b/src/tests/kernel/type_check.cpp @@ -184,6 +184,27 @@ static void tst10() { std::cout << env.get_object("simp_eq").pp(env) << "\n"; } +static void tst11() { + environment env = mk_toplevel(); + env.add_var("f", Int >> (Int >> Int)); + env.add_var("a", Int); + unsigned n = 1000; + expr f = Const("f"); + expr a = Const("a"); + expr t1 = f(a,a); + expr b = Const("a"); + expr t2 = f(a,a); + expr t3 = f(b,b); + for (unsigned i = 0; i < n; i++) { + t1 = f(t1,t1); + t2 = mk_let("x", t2, f(Var(0), Var(0))); + t3 = f(t3,t3); + } + lean_assert(t1 != t2); + env.add_theorem("eqs1", Eq(t1,t2), Refl(Int, t1)); + env.add_theorem("eqs2", Eq(t1,t3), Refl(Int, t1)); +} + int main() { tst1(); tst2(); @@ -195,5 +216,6 @@ int main() { tst8(); tst9(); tst10(); + tst11(); return has_violations() ? 1 : 0; } diff --git a/src/util/scoped_map.h b/src/util/scoped_map.h index aebf1ec2dd..122c914314 100644 --- a/src/util/scoped_map.h +++ b/src/util/scoped_map.h @@ -98,7 +98,6 @@ public: m_actions.push_back(std::make_pair(action_kind::Replace, *it)); it->second = v; } - lean_assert(m_map.find(k)->second == v); } void insert(value_type const & p) {