perf: precise cache for expr_eq_fn (#4890)
This performance issue was exposed by the benchmarks at https://github.com/leanprover/LNSym/tree/proof_size_expt/Proofs/SHA512/Experiments
This commit is contained in:
parent
c517688f1d
commit
a856016b9d
5 changed files with 95 additions and 70 deletions
|
|
@ -9,6 +9,7 @@ Author: Leonardo de Moura
|
|||
#include <limits>
|
||||
#include "runtime/sstream.h"
|
||||
#include "runtime/thread.h"
|
||||
#include "runtime/sharecommon.h"
|
||||
#include "util/map_foreach.h"
|
||||
#include "util/io.h"
|
||||
#include "kernel/environment.h"
|
||||
|
|
@ -220,12 +221,15 @@ environment environment::add_theorem(declaration const & d, bool check) const {
|
|||
theorem_val const & v = d.to_theorem_val();
|
||||
if (check) {
|
||||
type_checker checker(*this, diag.get());
|
||||
if (!checker.is_prop(v.get_type()))
|
||||
throw theorem_type_is_not_prop(*this, v.get_name(), v.get_type());
|
||||
sharecommon_persistent_fn share;
|
||||
expr val(share(v.get_value().raw()));
|
||||
expr type(share(v.get_type().raw()));
|
||||
if (!checker.is_prop(type))
|
||||
throw theorem_type_is_not_prop(*this, v.get_name(), type);
|
||||
check_constant_val(*this, v.to_constant_val(), checker);
|
||||
check_no_metavar_no_fvar(*this, v.get_name(), v.get_value());
|
||||
expr val_type = checker.check(v.get_value(), v.get_lparams());
|
||||
if (!checker.is_def_eq(val_type, v.get_type()))
|
||||
check_no_metavar_no_fvar(*this, v.get_name(), val);
|
||||
expr val_type = checker.check(val, v.get_lparams());
|
||||
if (!checker.is_def_eq(val_type, type))
|
||||
throw definition_type_mismatch_exception(*this, d, val_type);
|
||||
}
|
||||
return diag.update(add(constant_info(d)));
|
||||
|
|
|
|||
|
|
@ -11,65 +11,49 @@ Author: Leonardo de Moura
|
|||
#include "kernel/expr.h"
|
||||
#include "kernel/expr_sets.h"
|
||||
|
||||
#ifndef LEAN_EQ_CACHE_CAPACITY
|
||||
#define LEAN_EQ_CACHE_CAPACITY 1024*8
|
||||
#endif
|
||||
|
||||
namespace lean {
|
||||
struct eq_cache {
|
||||
struct entry {
|
||||
object * m_a;
|
||||
object * m_b;
|
||||
entry():m_a(nullptr), m_b(nullptr) {}
|
||||
};
|
||||
unsigned m_capacity;
|
||||
std::vector<entry> m_cache;
|
||||
std::vector<unsigned> m_used;
|
||||
eq_cache():m_capacity(LEAN_EQ_CACHE_CAPACITY), m_cache(LEAN_EQ_CACHE_CAPACITY) {}
|
||||
/**
|
||||
\brief Functional object for comparing expressions.
|
||||
|
||||
bool check(expr const & a, expr const & b) {
|
||||
if (!is_shared(a) || !is_shared(b))
|
||||
return false;
|
||||
unsigned i = hash(hash(a), hash(b)) % m_capacity;
|
||||
if (m_cache[i].m_a == a.raw() && m_cache[i].m_b == b.raw()) {
|
||||
return true;
|
||||
} else {
|
||||
if (m_cache[i].m_a == nullptr)
|
||||
m_used.push_back(i);
|
||||
m_cache[i].m_a = a.raw();
|
||||
m_cache[i].m_b = b.raw();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void clear() {
|
||||
for (unsigned i : m_used)
|
||||
m_cache[i].m_a = nullptr;
|
||||
m_used.clear();
|
||||
}
|
||||
};
|
||||
|
||||
/* CACHE_RESET: No */
|
||||
MK_THREAD_LOCAL_GET_DEF(eq_cache, get_eq_cache);
|
||||
|
||||
/** \brief Functional object for comparing expressions.
|
||||
|
||||
Remark if CompareBinderInfo is true, then functional object will also compare
|
||||
binder information attached to lambda and Pi expressions */
|
||||
Remark if CompareBinderInfo is true, then functional object will also compare
|
||||
binder information attached to lambda and Pi expressions
|
||||
*/
|
||||
template<bool CompareBinderInfo>
|
||||
class expr_eq_fn {
|
||||
eq_cache & m_cache;
|
||||
|
||||
struct key_hasher {
|
||||
std::size_t operator()(std::pair<lean_object *, lean_object *> const & p) const {
|
||||
return hash((size_t)p.first >> 3, (size_t)p.first >> 3);
|
||||
}
|
||||
};
|
||||
typedef std::unordered_set<std::pair<lean_object *, lean_object *>, key_hasher> cache;
|
||||
cache * m_cache = nullptr;
|
||||
bool check_cache(expr const & a, expr const & b) {
|
||||
if (!is_shared(a) || !is_shared(b))
|
||||
return false;
|
||||
if (!m_cache)
|
||||
m_cache = new cache();
|
||||
std::pair<lean_object *, lean_object *> key(a.raw(), b.raw());
|
||||
if (m_cache->find(key) != m_cache->end())
|
||||
return true;
|
||||
m_cache->insert(key);
|
||||
return false;
|
||||
}
|
||||
static void check_system() {
|
||||
::lean::check_system("expression equality test");
|
||||
}
|
||||
|
||||
bool apply(expr const & a, expr const & b) {
|
||||
bool apply(expr const & a, expr const & b, bool root = false) {
|
||||
if (is_eqp(a, b)) return true;
|
||||
if (hash(a) != hash(b)) return false;
|
||||
if (a.kind() != b.kind()) return false;
|
||||
if (is_bvar(a)) return bvar_idx(a) == bvar_idx(b);
|
||||
if (m_cache.check(a, b))
|
||||
switch (a.kind()) {
|
||||
case expr_kind::BVar: return bvar_idx(a) == bvar_idx(b);
|
||||
case expr_kind::Lit: return lit_value(a) == lit_value(b);
|
||||
case expr_kind::MVar: return mvar_name(a) == mvar_name(b);
|
||||
case expr_kind::FVar: return fvar_name(a) == fvar_name(b);
|
||||
case expr_kind::Sort: return sort_level(a) == sort_level(b);
|
||||
default: break;
|
||||
}
|
||||
if (!root && check_cache(a, b))
|
||||
return true;
|
||||
/*
|
||||
We increase the number of heartbeats here because some code (e.g., `simp`) may spend a lot of time comparing
|
||||
|
|
@ -78,6 +62,10 @@ class expr_eq_fn {
|
|||
lean_inc_heartbeat();
|
||||
switch (a.kind()) {
|
||||
case expr_kind::BVar:
|
||||
case expr_kind::Lit:
|
||||
case expr_kind::MVar:
|
||||
case expr_kind::FVar:
|
||||
case expr_kind::Sort:
|
||||
lean_unreachable(); // LCOV_EXCL_LINE
|
||||
case expr_kind::MData:
|
||||
return
|
||||
|
|
@ -88,16 +76,10 @@ class expr_eq_fn {
|
|||
apply(proj_expr(a), proj_expr(b)) &&
|
||||
proj_sname(a) == proj_sname(b) &&
|
||||
proj_idx(a) == proj_idx(b);
|
||||
case expr_kind::Lit:
|
||||
return lit_value(a) == lit_value(b);
|
||||
case expr_kind::Const:
|
||||
return
|
||||
const_name(a) == const_name(b) &&
|
||||
compare(const_levels(a), const_levels(b), [](level const & l1, level const & l2) { return l1 == l2; });
|
||||
case expr_kind::MVar:
|
||||
return mvar_name(a) == mvar_name(b);
|
||||
case expr_kind::FVar:
|
||||
return fvar_name(a) == fvar_name(b);
|
||||
case expr_kind::App:
|
||||
check_system();
|
||||
return
|
||||
|
|
@ -117,15 +99,13 @@ class expr_eq_fn {
|
|||
apply(let_value(a), let_value(b)) &&
|
||||
apply(let_body(a), let_body(b)) &&
|
||||
(!CompareBinderInfo || let_name(a) == let_name(b));
|
||||
case expr_kind::Sort:
|
||||
return sort_level(a) == sort_level(b);
|
||||
}
|
||||
lean_unreachable(); // LCOV_EXCL_LINE
|
||||
}
|
||||
public:
|
||||
expr_eq_fn():m_cache(get_eq_cache()) {}
|
||||
~expr_eq_fn() { m_cache.clear(); }
|
||||
bool operator()(expr const & a, expr const & b) { return apply(a, b); }
|
||||
expr_eq_fn() {}
|
||||
~expr_eq_fn() { if (m_cache) delete m_cache; }
|
||||
bool operator()(expr const & a, expr const & b) { return apply(a, b, true); }
|
||||
};
|
||||
|
||||
bool is_equal(expr const & a, expr const & b) {
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ namespace lean {
|
|||
class replace_rec_fn {
|
||||
struct key_hasher {
|
||||
std::size_t operator()(std::pair<lean_object *, unsigned> const & p) const {
|
||||
return hash((size_t)p.first, p.second);
|
||||
return hash((size_t)p.first >> 3, p.second);
|
||||
}
|
||||
};
|
||||
std::unordered_map<std::pair<lean_object *, unsigned>, expr, key_hasher> m_cache;
|
||||
|
|
|
|||
|
|
@ -4,10 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include "runtime/sharecommon.h"
|
||||
#include "runtime/hash.h"
|
||||
|
||||
|
|
@ -294,6 +291,15 @@ lean_object * sharecommon_quick_fn::check_cache(lean_object * a) {
|
|||
it->second->m_rc++;
|
||||
return it->second;
|
||||
}
|
||||
if (m_check_set) {
|
||||
auto it = m_set.find(a);
|
||||
if (it != m_set.end()) {
|
||||
lean_object * result = *it;
|
||||
lean_assert(lean_is_st(result));
|
||||
result->m_rc++;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
|
@ -416,4 +422,14 @@ lean_object * sharecommon_quick_fn::visit(lean_object * a) {
|
|||
extern "C" LEAN_EXPORT obj_res lean_sharecommon_quick(obj_arg a) {
|
||||
return sharecommon_quick_fn()(a);
|
||||
}
|
||||
|
||||
lean_object * sharecommon_persistent_fn::operator()(lean_object * e) {
|
||||
lean_object * r = check_cache(e);
|
||||
if (r != nullptr)
|
||||
return r;
|
||||
m_saved.push_back(object_ref(e, true));
|
||||
r = visit(e);
|
||||
m_saved.push_back(object_ref(r, true));
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -5,7 +5,10 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Author: Leonardo de Moura
|
||||
*/
|
||||
#pragma once
|
||||
#include "runtime/object.h"
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include "runtime/object_ref.h"
|
||||
|
||||
namespace lean {
|
||||
extern "C" LEAN_EXPORT uint8 lean_sharecommon_eq(b_obj_arg o1, b_obj_arg o2);
|
||||
|
|
@ -17,6 +20,7 @@ It optimizes the number of RC operations, the strategy for caching results,
|
|||
and uses C++ hashmap.
|
||||
*/
|
||||
class sharecommon_quick_fn {
|
||||
protected:
|
||||
struct set_hash {
|
||||
std::size_t operator()(lean_object * o) const { return lean_sharecommon_hash(o); }
|
||||
};
|
||||
|
|
@ -31,6 +35,12 @@ class sharecommon_quick_fn {
|
|||
std::unordered_map<lean_object *, lean_object *> m_cache;
|
||||
/* Set of maximally shared terms. AKA hash-consing table. */
|
||||
std::unordered_set<lean_object *, set_hash, set_eq> m_set;
|
||||
/*
|
||||
If `true`, `check_cache` will also check `m_set`.
|
||||
This is useful when the input term may contain terms that have already
|
||||
been hashconsed.
|
||||
*/
|
||||
bool m_check_set;
|
||||
|
||||
lean_object * check_cache(lean_object * a);
|
||||
lean_object * save(lean_object * a, lean_object * new_a);
|
||||
|
|
@ -39,8 +49,23 @@ class sharecommon_quick_fn {
|
|||
lean_object * visit_ctor(lean_object * a);
|
||||
lean_object * visit(lean_object * a);
|
||||
public:
|
||||
sharecommon_quick_fn(bool s = false):m_check_set(s) {}
|
||||
void set_check_set(bool f) { m_check_set = f; }
|
||||
lean_object * operator()(lean_object * a) {
|
||||
return visit(a);
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
Similar to `sharecommon_quick_fn`, but we save the entry points and result values to ensure
|
||||
they are not deleted.
|
||||
*/
|
||||
class sharecommon_persistent_fn : private sharecommon_quick_fn {
|
||||
std::vector<object_ref> m_saved;
|
||||
public:
|
||||
sharecommon_persistent_fn(bool s = false):sharecommon_quick_fn(s) {}
|
||||
void set_check_set(bool f) { m_check_set = f; }
|
||||
lean_object * operator()(lean_object * e);
|
||||
};
|
||||
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue