diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index f5d9f02e43..3bf6717308 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -86,6 +86,25 @@ expr_cell::expr_cell(expr_kind k, unsigned h, bool has_expr_mv, bool has_univ_mv #endif } +expr_cell::expr_cell(expr_cell const & src): + m_kind(src.m_kind), + m_has_expr_mv(src.m_has_expr_mv), + m_has_univ_mv(src.m_has_univ_mv), + m_has_local(src.m_has_local), + m_has_param_univ(src.m_has_param_univ), + m_hash(src.m_hash), + m_rc(0) { + unsigned flgs = src.m_flags; + unsigned tag = src.m_tag; + m_flags = flgs; + m_tag = tag; + m_hash_alloc = g_hash_alloc_counter; + g_hash_alloc_counter++; + #ifdef LEAN_TRACK_LIVE_EXPRS + atomic_fetch_add_explicit(&g_num_live_exprs, 1u, memory_order_release); + #endif +} + void expr_cell::dec_ref(expr & e, buffer & todelete) { if (e.m_ptr) { expr_cell * c = e.steal_ptr(); @@ -160,23 +179,36 @@ expr_mlocal::expr_mlocal(bool is_meta, name const & n, expr const & t, tag g): 1, get_free_var_range(t), g), m_name(n), m_type(t) {} + void expr_mlocal::dealloc(buffer & todelete) { dec_ref(m_type, todelete); this->~expr_mlocal(); get_mlocal_allocator().recycle(this); } +expr_mlocal::expr_mlocal(expr_mlocal const & src, expr const & new_type): + expr_composite(src), m_name(src.m_name), m_type(new_type) {} + DEF_THREAD_MEMORY_POOL(get_local_allocator, sizeof(expr_local)); expr_local::expr_local(name const & n, name const & pp_name, expr const & t, binder_info const & bi, tag g): - expr_mlocal(false, n, t, g), - m_pp_name(pp_name), - m_bi(bi) {} + expr_mlocal(false, n, t, g), m_pp_name(pp_name), m_bi(bi) {} + +expr_local::expr_local(expr_local const & src, expr const & new_type): + expr_mlocal(src, new_type), m_pp_name(src.m_pp_name), m_bi(src.m_bi) {} + void expr_local::dealloc(buffer & todelete) { dec_ref(m_type, todelete); this->~expr_local(); get_local_allocator().recycle(this); } +expr_composite::expr_composite(expr_composite const & src): + expr_cell(src), + m_weight(src.m_weight), + m_depth(src.m_depth), + m_free_var_range(src.m_free_var_range) { +} + // Composite expressions expr_composite::expr_composite(expr_kind k, unsigned h, bool has_expr_mv, bool has_univ_mv, bool has_local, bool has_param_univ, unsigned w, unsigned fv_range, tag g): @@ -201,6 +233,10 @@ expr_app::expr_app(expr const & fn, expr const & arg, tag g): m_hash = ::lean::hash(m_hash, m_weight); m_hash = ::lean::hash(m_hash, m_depth); } + +expr_app::expr_app(expr_app const & src, expr const & new_fn, expr const & new_arg): + expr_composite(src), m_fn(new_fn), m_arg(new_arg) {} + void expr_app::dealloc(buffer & todelete) { dec_ref(m_fn, todelete); dec_ref(m_arg, todelete); @@ -236,6 +272,10 @@ expr_binding::expr_binding(expr_kind k, name const & n, expr const & t, expr con m_hash = ::lean::hash(m_hash, m_depth); lean_assert(k == expr_kind::Lambda || k == expr_kind::Pi); } + +expr_binding::expr_binding(expr_binding const & src, expr const & d, expr const & b): + expr_composite(src), m_binder(src.m_binder, d), m_body(b) {} + void expr_binding::dealloc(buffer & todelete) { dec_ref(m_body, todelete); dec_ref(m_binder.m_type, todelete); @@ -272,6 +312,10 @@ expr_let::expr_let(name const & n, expr const & t, expr const & v, expr const & m_hash = ::lean::hash(m_hash, m_weight); m_hash = ::lean::hash(m_hash, m_depth); } + +expr_let::expr_let(expr_let const & src, expr const & t, expr const & v, expr const & b): + expr_composite(src), m_name(src.m_name), m_type(t), m_value(v), m_body(b) {} + void expr_let::dealloc(buffer & todelete) { dec_ref(m_body, todelete); dec_ref(m_value, todelete); @@ -318,6 +362,14 @@ static unsigned get_free_var_range(unsigned num, expr const * args) { return r; } +expr_macro::expr_macro(expr_macro const & src, expr const * new_args): + expr_composite(src), + m_definition(src.m_definition), + m_num_args(src.m_num_args) { + expr * data = get_args_ptr(); + std::uninitialized_copy(new_args, new_args + m_num_args, data); +} + expr_macro::expr_macro(macro_definition const & m, unsigned num, expr const * args, tag g): expr_composite(expr_kind::Macro, lean::hash(num, [&](unsigned i) { return args[i].hash(); }, m.hash()), @@ -354,17 +406,119 @@ expr_macro::~expr_macro() {} LEAN_THREAD_VALUE(bool, g_expr_cache_enabled, false); typedef typename std::unordered_set expr_cache; MK_THREAD_LOCAL_GET_DEF(expr_cache, get_expr_cache); -inline expr cache(expr const & e) { - if (g_expr_cache_enabled) { - expr_cache & cache = get_expr_cache(); - auto it = cache.find(e); - if (it != cache.end()) { - return *it; + +struct cache_expr_insert_fn { + expr_cache & m_cache; + cache_expr_insert_fn(expr_cache & c):m_cache(c) {} + + expr insert_macro(expr const & e) { + buffer new_args; + bool updated = false; + unsigned num = macro_num_args(e); + for (unsigned i = 0; i < num; i++) { + expr const & arg = macro_arg(e, i); + new_args.push_back(insert(arg)); + if (!is_eqp(arg, new_args.back())) + updated = true; + } + if (updated) { + char * mem = new char[sizeof(expr_macro) + num*sizeof(expr const *)]; + return expr(new (mem) expr_macro(*to_macro(e), new_args.data())); } else { - cache.insert(e); + return e; } } - return e; + + expr insert_meta(expr const & e) { + expr new_type = insert(mlocal_type(e)); + if (is_eqp(new_type, mlocal_type(e))) { + return e; + } else { + return expr(new (get_mlocal_allocator().allocate()) expr_mlocal(*to_mlocal(e), new_type)); + } + } + + expr insert_local(expr const & e) { + expr new_type = insert(mlocal_type(e)); + if (is_eqp(new_type, mlocal_type(e))) { + return e; + } else { + return expr(new (get_local_allocator().allocate()) expr_local(*to_local(e), new_type)); + } + } + + expr insert_constant(expr const & e) { + /* TODO(Leo): similar insert for levels */ + return e; + } + + expr insert_sort(expr const & e) { + /* TODO(Leo): similar insert for levels */ + return e; + } + + expr insert_app(expr const & e) { + expr new_fn = insert(app_fn(e)); + expr new_arg = insert(app_arg(e)); + if (is_eqp(new_fn, app_fn(e)) && is_eqp(new_arg, app_arg(e))) { + return e; + } else { + return expr(new (get_app_allocator().allocate()) expr_app(*to_app(e), new_fn, new_arg)); + } + } + + expr insert_binding(expr const & e) { + expr new_domain = insert(binding_domain(e)); + expr new_body = insert(binding_body(e)); + if (is_eqp(new_domain, binding_domain(e)) && is_eqp(new_body, binding_body(e))) { + return e; + } else { + return expr(new (get_binding_allocator().allocate()) expr_binding(*to_binding(e), new_domain, new_body)); + } + } + + expr insert_let(expr const & e) { + expr new_type = insert(let_type(e)); + expr new_value = insert(let_value(e)); + expr new_body = insert(let_body(e)); + if (is_eqp(new_type, let_type(e)) && is_eqp(new_value, let_value(e)) && is_eqp(new_body, let_body(e))) { + return e; + } else { + return expr(new (get_let_allocator().allocate()) expr_let(*to_let(e), new_type, new_value, new_body)); + } + return e; + } + + expr insert(expr const & e) { + auto it = m_cache.find(e); + if (it != m_cache.end()) { + return *it; + } + expr new_e; + switch (e.kind()) { + case expr_kind::Var: new_e = e; break; + case expr_kind::Macro: new_e = insert_macro(e); break; + case expr_kind::Meta: new_e = insert_meta(e); break; + case expr_kind::Local: new_e = insert_local(e); break; + case expr_kind::Constant: new_e = insert_constant(e); break; + case expr_kind::Sort: new_e = insert_sort(e); break; + case expr_kind::App: new_e = insert_app(e); break; + case expr_kind::Lambda: new_e = insert_binding(e); break; + case expr_kind::Pi: new_e = insert_binding(e); break; + case expr_kind::Let: new_e = insert_let(e); break; + } + m_cache.insert(new_e); + return new_e; + } + + expr operator()(expr const & e) { return insert(e); } +}; + +inline expr cache(expr const & e) { + if (g_expr_cache_enabled) + return cache_expr_insert_fn(get_expr_cache())(e); + else + return e; } bool enable_expr_caching(bool f) { DEBUG_CODE(bool r1 =) enable_level_caching(f); diff --git a/src/kernel/expr.h b/src/kernel/expr.h index 3ae8717baf..8fb9373410 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -67,7 +67,8 @@ protected: void set_is_arrow(bool flag); friend bool is_arrow(expr const & e); - static void dec_ref(expr & c, buffer & todelete); + static void dec_ref(expr & c, buffer & todelete); + expr_cell(expr_cell const & src); // for hash_consing public: expr_cell(expr_kind k, unsigned h, bool has_expr_mv, bool has_univ_mv, bool has_local, bool has_param_univ, tag g); expr_kind kind() const { return static_cast(m_kind); } @@ -94,6 +95,7 @@ private: expr_cell * m_ptr; explicit expr(expr_cell * ptr):m_ptr(ptr) { if (m_ptr) m_ptr->inc_ref(); } friend class expr_cell; + friend struct cache_expr_insert_fn; expr_cell * steal_ptr() { expr_cell * r = m_ptr; m_ptr = nullptr; return r; } friend class optional; public: @@ -172,6 +174,8 @@ class expr_const : public expr_cell { levels m_levels; friend expr_cell; void dealloc(); + friend struct cache_expr_insert_fn; + expr_const(expr_const const &, levels const & new_levels); // for hash_consing public: expr_const(name const & n, levels const & ls, tag g); name const & get_name() const { return m_name; } @@ -187,6 +191,7 @@ protected: friend unsigned get_weight(expr const & e); friend unsigned get_depth(expr const & e); friend unsigned get_free_var_range(expr const & e); + expr_composite(expr_composite const & src); // for hash_consing public: expr_composite(expr_kind k, unsigned h, bool has_expr_mv, bool has_univ_mv, bool has_local, bool has_param_univ, unsigned w, unsigned fv_range, tag g); @@ -199,6 +204,8 @@ protected: expr m_type; friend expr_cell; void dealloc(buffer & todelete); + friend struct cache_expr_insert_fn; + expr_mlocal(expr_mlocal const &, expr const & new_type); // for hash_consing public: expr_mlocal(bool is_meta, name const & n, expr const & t, tag g); name const & get_name() const { return m_name; } @@ -254,6 +261,8 @@ class expr_local : public expr_mlocal { binder_info m_bi; friend expr_cell; void dealloc(buffer & todelete); + friend struct cache_expr_insert_fn; + expr_local(expr_local const &, expr const & new_type); // for hash_consing public: expr_local(name const & n, name const & pp_name, expr const & t, binder_info const & bi, tag g); name const & get_pp_name() const { return m_pp_name; } @@ -266,6 +275,8 @@ class expr_app : public expr_composite { expr m_arg; friend expr_cell; void dealloc(buffer & todelete); + friend struct cache_expr_insert_fn; + expr_app(expr_app const &, expr const & new_fn, expr const & new_arg); // for hash_consing public: expr_app(expr const & fn, expr const & arg, tag g); expr const & get_fn() const { return m_fn; } @@ -277,6 +288,8 @@ class binder { name m_name; expr m_type; binder_info m_info; + binder(binder const & src, expr const & new_type): // for hash_consing + m_name(src.m_name), m_type(new_type), m_info(src.m_info) {} public: binder(name const & n, expr const & t, binder_info const & bi): m_name(n), m_type(t), m_info(bi) {} @@ -292,6 +305,8 @@ class expr_binding : public expr_composite { expr m_body; friend class expr_cell; void dealloc(buffer & todelete); + friend struct cache_expr_insert_fn; + expr_binding(expr_binding const &, expr const & new_domain, expr const & new_body); // for hash_consing public: expr_binding(expr_kind k, name const & n, expr const & t, expr const & e, binder_info const & i, tag g); @@ -310,6 +325,8 @@ class expr_let : public expr_composite { expr m_body; friend class expr_cell; void dealloc(buffer & todelete); + friend struct cache_expr_insert_fn; + expr_let(expr_let const &, expr const & new_type, expr const & new_value, expr const & new_body); // for hash_consing public: expr_let(name const & n, expr const & t, expr const & v, expr const & b, tag g); name const & get_name() const { return m_name; } @@ -323,6 +340,8 @@ class expr_sort : public expr_cell { level m_level; friend expr_cell; void dealloc(); + friend struct cache_expr_insert_fn; + expr_sort(expr_sort const &, level const & new_level); // for hash_consing public: expr_sort(level const & l, tag g); ~expr_sort(); @@ -402,6 +421,8 @@ class expr_macro : public expr_composite { expr const * get_args_ptr() const { return reinterpret_cast(reinterpret_cast(this)+sizeof(expr_macro)); } + friend struct cache_expr_insert_fn; + expr_macro(expr_macro const & src, expr const * new_args); // for hash_consing public: expr_macro(macro_definition const & v, unsigned num, expr const * args, tag g); ~expr_macro();