diff --git a/src/library/compiler/emit_cpp.cpp b/src/library/compiler/emit_cpp.cpp index ecc1708f0a..a6507ac622 100644 --- a/src/library/compiler/emit_cpp.cpp +++ b/src/library/compiler/emit_cpp.cpp @@ -909,6 +909,10 @@ static void emit_initialize(std::ostream & out, environment const & env, module_ expr const & code = d.snd(); if (!is_lambda(code)) { out << " " << to_cpp_name(env, n) << " = " << to_cpp_init_name(env, n) << "();\n"; + expr type = get_constant_ll_type(env, n); + if (is_pi(type) || is_enf_object_type(type)) { + out << "lean::mark_persistent(" << to_cpp_name(env, n) << ");\n"; + } } } out << "}\n"; diff --git a/src/runtime/object.cpp b/src/runtime/object.cpp index b3d65c027f..0bbf5d688a 100644 --- a/src/runtime/object.cpp +++ b/src/runtime/object.cpp @@ -760,14 +760,14 @@ void deactivate_task(task_object * t) { g_task_manager->deactivate_task(t); } -void to_mt(object * o); -static obj_res to_mt_fn(obj_arg o) { - to_mt(o); +void mark_mt(object * o); +static obj_res mark_mt_fn(obj_arg o) { + mark_mt(o); dec(o); return box(0); } -void to_mt(object * o) { +void mark_mt(object * o) { if (is_scalar(o) || !is_st_heap_obj(o)) return; o->m_mem_kind = static_cast(object_memory_kind::MTHeap); @@ -780,47 +780,47 @@ void to_mt(object * o) { case object_kind::PArrayPush: case object_kind::PArraySet: case object_kind::PArrayRoot: - /* `to_mt` cannot be used with parray. They must be copied when used in multiple threads. */ + /* `mark_mt` cannot be used with parray. They must be copied when used in multiple threads. */ lean_unreachable(); return; case object_kind::External: { - object * fn = alloc_closure(reinterpret_cast(to_mt_fn), 1, 0); + object * fn = alloc_closure(reinterpret_cast(mark_mt_fn), 1, 0); to_external(o)->for_each_nested(fn); dec(fn); return; } case object_kind::Task: - to_mt(task_get(o)); + mark_mt(task_get(o)); return; case object_kind::Constructor: { object ** it = cnstr_obj_cptr(o); object ** end = it + cnstr_num_objs(o); - for (; it != end; ++it) to_mt(*it); - break; + for (; it != end; ++it) mark_mt(*it); + return; } case object_kind::Closure: { object ** it = closure_arg_cptr(o); object ** end = it + closure_num_fixed(o); - for (; it != end; ++it) to_mt(*it); - break; + for (; it != end; ++it) mark_mt(*it); + return; } case object_kind::Array: { object ** it = array_cptr(o); object ** end = it + array_size(o); - for (; it != end; ++it) to_mt(*it); - break; + for (; it != end; ++it) mark_mt(*it); + return; } case object_kind::Thunk: - if (object * c = to_thunk(o)->m_closure) to_mt(c); - if (object * v = to_thunk(o)->m_value) to_mt(v); - break; + if (object * c = to_thunk(o)->m_closure) mark_mt(c); + if (object * v = to_thunk(o)->m_value) mark_mt(v); + return; } } task_object::task_object(obj_arg c, unsigned prio): object(object_kind::Task, object_memory_kind::MTHeap), m_value(nullptr), m_imp(new imp(c, prio)) { lean_assert(is_closure(c)); - to_mt(c); + mark_mt(c); } task_object::task_object(obj_arg v): @@ -960,6 +960,69 @@ b_obj_res io_wait_any_core(b_obj_arg task_list) { return g_task_manager->wait_any(task_list); } +void mark_persistent(object * o); +static obj_res mark_persistent_fn(obj_arg o) { + mark_persistent(o); + return box(0); +} + +void mark_persistent(object * o) { + if (is_scalar(o) || !is_heap_obj(o)) return; + o->m_mem_kind = static_cast(object_memory_kind::Persistent); + + switch (get_kind(o)) { + case object_kind::ScalarArray: + case object_kind::String: + case object_kind::MPZ: + return; + case object_kind::PArrayPop: + mark_persistent(to_parray(o)->m_next); + return; + case object_kind::PArrayPush: + case object_kind::PArraySet: + mark_persistent(to_parray(o)->m_elem); + mark_persistent(to_parray(o)->m_next); + return; + case object_kind::PArrayRoot: { + object ** it = to_parray(o)->m_data; + object ** end = it + to_parray(o)->m_size; + for (; it != end; ++it) mark_persistent(*it); + return; + } + case object_kind::External: { + object * fn = alloc_closure(reinterpret_cast(mark_persistent_fn), 1, 0); + to_external(o)->for_each_nested(fn); + dec(fn); + return; + } + case object_kind::Task: + mark_persistent(task_get(o)); + return; + case object_kind::Constructor: { + object ** it = cnstr_obj_cptr(o); + object ** end = it + cnstr_num_objs(o); + for (; it != end; ++it) mark_persistent(*it); + return; + } + case object_kind::Closure: { + object ** it = closure_arg_cptr(o); + object ** end = it + closure_num_fixed(o); + for (; it != end; ++it) mark_persistent(*it); + return; + } + case object_kind::Array: { + object ** it = array_cptr(o); + object ** end = it + array_size(o); + for (; it != end; ++it) mark_persistent(*it); + return; + } + case object_kind::Thunk: + if (object * c = to_thunk(o)->m_closure) mark_persistent(c); + if (object * v = to_thunk(o)->m_value) mark_persistent(v); + return; + } +} + // ======================================= // Natural numbers diff --git a/src/runtime/object.h b/src/runtime/object.h index 2974639864..0b9222cc6a 100644 --- a/src/runtime/object.h +++ b/src/runtime/object.h @@ -388,6 +388,9 @@ inline void obj_set_data(object * o, size_t offset, T v) { *(reinterpret_cast(reinterpret_cast(o) + offset)) = v; } +/* Mark all objects reachable from `o` as persistent */ +void mark_persistent(object * o); + // ======================================= // Constructor auxiliary functions