diff --git a/src/kernel/environment.cpp b/src/kernel/environment.cpp index 7757a90a55..57e0a0358e 100644 --- a/src/kernel/environment.cpp +++ b/src/kernel/environment.cpp @@ -9,6 +9,7 @@ Author: Leonardo de Moura #include #include "runtime/sstream.h" #include "runtime/thread.h" +#include "runtime/sharecommon.h" #include "util/map_foreach.h" #include "util/io.h" #include "kernel/environment.h" @@ -220,12 +221,15 @@ environment environment::add_theorem(declaration const & d, bool check) const { theorem_val const & v = d.to_theorem_val(); if (check) { type_checker checker(*this, diag.get()); - if (!checker.is_prop(v.get_type())) - throw theorem_type_is_not_prop(*this, v.get_name(), v.get_type()); + sharecommon_persistent_fn share; + expr val(share(v.get_value().raw())); + expr type(share(v.get_type().raw())); + if (!checker.is_prop(type)) + throw theorem_type_is_not_prop(*this, v.get_name(), type); check_constant_val(*this, v.to_constant_val(), checker); - check_no_metavar_no_fvar(*this, v.get_name(), v.get_value()); - expr val_type = checker.check(v.get_value(), v.get_lparams()); - if (!checker.is_def_eq(val_type, v.get_type())) + check_no_metavar_no_fvar(*this, v.get_name(), val); + expr val_type = checker.check(val, v.get_lparams()); + if (!checker.is_def_eq(val_type, type)) throw definition_type_mismatch_exception(*this, d, val_type); } return diag.update(add(constant_info(d))); diff --git a/src/kernel/expr_eq_fn.cpp b/src/kernel/expr_eq_fn.cpp index 647bbec2d3..875b4e2dc1 100644 --- a/src/kernel/expr_eq_fn.cpp +++ b/src/kernel/expr_eq_fn.cpp @@ -11,65 +11,49 @@ Author: Leonardo de Moura #include "kernel/expr.h" #include "kernel/expr_sets.h" -#ifndef LEAN_EQ_CACHE_CAPACITY -#define LEAN_EQ_CACHE_CAPACITY 1024*8 -#endif - namespace lean { -struct eq_cache { - struct entry { - object * m_a; - object * m_b; - entry():m_a(nullptr), m_b(nullptr) {} - }; - unsigned m_capacity; - std::vector m_cache; - std::vector m_used; - eq_cache():m_capacity(LEAN_EQ_CACHE_CAPACITY), m_cache(LEAN_EQ_CACHE_CAPACITY) {} +/** +\brief Functional object for comparing expressions. - bool check(expr const & a, expr const & b) { - if (!is_shared(a) || !is_shared(b)) - return false; - unsigned i = hash(hash(a), hash(b)) % m_capacity; - if (m_cache[i].m_a == a.raw() && m_cache[i].m_b == b.raw()) { - return true; - } else { - if (m_cache[i].m_a == nullptr) - m_used.push_back(i); - m_cache[i].m_a = a.raw(); - m_cache[i].m_b = b.raw(); - return false; - } - } - - void clear() { - for (unsigned i : m_used) - m_cache[i].m_a = nullptr; - m_used.clear(); - } -}; - -/* CACHE_RESET: No */ -MK_THREAD_LOCAL_GET_DEF(eq_cache, get_eq_cache); - -/** \brief Functional object for comparing expressions. - - Remark if CompareBinderInfo is true, then functional object will also compare - binder information attached to lambda and Pi expressions */ +Remark if CompareBinderInfo is true, then functional object will also compare +binder information attached to lambda and Pi expressions +*/ template class expr_eq_fn { - eq_cache & m_cache; - + struct key_hasher { + std::size_t operator()(std::pair const & p) const { + return hash((size_t)p.first >> 3, (size_t)p.first >> 3); + } + }; + typedef std::unordered_set, key_hasher> cache; + cache * m_cache = nullptr; + bool check_cache(expr const & a, expr const & b) { + if (!is_shared(a) || !is_shared(b)) + return false; + if (!m_cache) + m_cache = new cache(); + std::pair key(a.raw(), b.raw()); + if (m_cache->find(key) != m_cache->end()) + return true; + m_cache->insert(key); + return false; + } static void check_system() { ::lean::check_system("expression equality test"); } - - bool apply(expr const & a, expr const & b) { + bool apply(expr const & a, expr const & b, bool root = false) { if (is_eqp(a, b)) return true; if (hash(a) != hash(b)) return false; if (a.kind() != b.kind()) return false; - if (is_bvar(a)) return bvar_idx(a) == bvar_idx(b); - if (m_cache.check(a, b)) + switch (a.kind()) { + case expr_kind::BVar: return bvar_idx(a) == bvar_idx(b); + case expr_kind::Lit: return lit_value(a) == lit_value(b); + case expr_kind::MVar: return mvar_name(a) == mvar_name(b); + case expr_kind::FVar: return fvar_name(a) == fvar_name(b); + case expr_kind::Sort: return sort_level(a) == sort_level(b); + default: break; + } + if (!root && check_cache(a, b)) return true; /* We increase the number of heartbeats here because some code (e.g., `simp`) may spend a lot of time comparing @@ -78,6 +62,10 @@ class expr_eq_fn { lean_inc_heartbeat(); switch (a.kind()) { case expr_kind::BVar: + case expr_kind::Lit: + case expr_kind::MVar: + case expr_kind::FVar: + case expr_kind::Sort: lean_unreachable(); // LCOV_EXCL_LINE case expr_kind::MData: return @@ -88,16 +76,10 @@ class expr_eq_fn { apply(proj_expr(a), proj_expr(b)) && proj_sname(a) == proj_sname(b) && proj_idx(a) == proj_idx(b); - case expr_kind::Lit: - return lit_value(a) == lit_value(b); case expr_kind::Const: return const_name(a) == const_name(b) && compare(const_levels(a), const_levels(b), [](level const & l1, level const & l2) { return l1 == l2; }); - case expr_kind::MVar: - return mvar_name(a) == mvar_name(b); - case expr_kind::FVar: - return fvar_name(a) == fvar_name(b); case expr_kind::App: check_system(); return @@ -117,15 +99,13 @@ class expr_eq_fn { apply(let_value(a), let_value(b)) && apply(let_body(a), let_body(b)) && (!CompareBinderInfo || let_name(a) == let_name(b)); - case expr_kind::Sort: - return sort_level(a) == sort_level(b); } lean_unreachable(); // LCOV_EXCL_LINE } public: - expr_eq_fn():m_cache(get_eq_cache()) {} - ~expr_eq_fn() { m_cache.clear(); } - bool operator()(expr const & a, expr const & b) { return apply(a, b); } + expr_eq_fn() {} + ~expr_eq_fn() { if (m_cache) delete m_cache; } + bool operator()(expr const & a, expr const & b) { return apply(a, b, true); } }; bool is_equal(expr const & a, expr const & b) { diff --git a/src/kernel/replace_fn.cpp b/src/kernel/replace_fn.cpp index 65b9a07dc2..0dde49600e 100644 --- a/src/kernel/replace_fn.cpp +++ b/src/kernel/replace_fn.cpp @@ -14,7 +14,7 @@ namespace lean { class replace_rec_fn { struct key_hasher { std::size_t operator()(std::pair const & p) const { - return hash((size_t)p.first, p.second); + return hash((size_t)p.first >> 3, p.second); } }; std::unordered_map, expr, key_hasher> m_cache; diff --git a/src/runtime/sharecommon.cpp b/src/runtime/sharecommon.cpp index 47a3b80f41..53305df97d 100644 --- a/src/runtime/sharecommon.cpp +++ b/src/runtime/sharecommon.cpp @@ -4,10 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ -#include #include -#include -#include #include "runtime/sharecommon.h" #include "runtime/hash.h" @@ -294,6 +291,15 @@ lean_object * sharecommon_quick_fn::check_cache(lean_object * a) { it->second->m_rc++; return it->second; } + if (m_check_set) { + auto it = m_set.find(a); + if (it != m_set.end()) { + lean_object * result = *it; + lean_assert(lean_is_st(result)); + result->m_rc++; + return result; + } + } } return nullptr; } @@ -416,4 +422,14 @@ lean_object * sharecommon_quick_fn::visit(lean_object * a) { extern "C" LEAN_EXPORT obj_res lean_sharecommon_quick(obj_arg a) { return sharecommon_quick_fn()(a); } + +lean_object * sharecommon_persistent_fn::operator()(lean_object * e) { + lean_object * r = check_cache(e); + if (r != nullptr) + return r; + m_saved.push_back(object_ref(e, true)); + r = visit(e); + m_saved.push_back(object_ref(r, true)); + return r; +} }; diff --git a/src/runtime/sharecommon.h b/src/runtime/sharecommon.h index 1ac304182e..c88d632c9a 100644 --- a/src/runtime/sharecommon.h +++ b/src/runtime/sharecommon.h @@ -5,7 +5,10 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #pragma once -#include "runtime/object.h" +#include +#include +#include +#include "runtime/object_ref.h" namespace lean { extern "C" LEAN_EXPORT uint8 lean_sharecommon_eq(b_obj_arg o1, b_obj_arg o2); @@ -17,6 +20,7 @@ It optimizes the number of RC operations, the strategy for caching results, and uses C++ hashmap. */ class sharecommon_quick_fn { +protected: struct set_hash { std::size_t operator()(lean_object * o) const { return lean_sharecommon_hash(o); } }; @@ -31,6 +35,12 @@ class sharecommon_quick_fn { std::unordered_map m_cache; /* Set of maximally shared terms. AKA hash-consing table. */ std::unordered_set m_set; + /* + If `true`, `check_cache` will also check `m_set`. + This is useful when the input term may contain terms that have already + been hashconsed. + */ + bool m_check_set; lean_object * check_cache(lean_object * a); lean_object * save(lean_object * a, lean_object * new_a); @@ -39,8 +49,23 @@ class sharecommon_quick_fn { lean_object * visit_ctor(lean_object * a); lean_object * visit(lean_object * a); public: + sharecommon_quick_fn(bool s = false):m_check_set(s) {} + void set_check_set(bool f) { m_check_set = f; } lean_object * operator()(lean_object * a) { return visit(a); } }; + +/* +Similar to `sharecommon_quick_fn`, but we save the entry points and result values to ensure +they are not deleted. +*/ +class sharecommon_persistent_fn : private sharecommon_quick_fn { + std::vector m_saved; +public: + sharecommon_persistent_fn(bool s = false):sharecommon_quick_fn(s) {} + void set_check_set(bool f) { m_check_set = f; } + lean_object * operator()(lean_object * e); +}; + };