diff --git a/src/runtime/object.cpp b/src/runtime/object.cpp index 19c4d8aec2..f9be8fe9bd 100644 --- a/src/runtime/object.cpp +++ b/src/runtime/object.cpp @@ -103,7 +103,7 @@ void del(object * o) { case object_kind::MPZ: dealloc_mpz(o); break; case object_kind::Thunk: - dec(to_thunk(o)->m_closure, todo); + if (object * c = to_thunk(o)->m_closure) dec(c, todo); if (object * v = to_thunk(o)->m_value) dec(v, todo); free(o); break; diff --git a/src/runtime/object.h b/src/runtime/object.h index a4fb527c9f..85d08b3a2a 100644 --- a/src/runtime/object.h +++ b/src/runtime/object.h @@ -111,7 +111,7 @@ struct mpz_object : public object { struct thunk_object : public object { object * m_closure; atomic m_value; - thunk_object(object * c); + thunk_object(object * c, bool is_value = false); }; /* Base class for wrapping external_object data. @@ -330,15 +330,26 @@ inline mpz const & mpz_value(object * o) { return to_mpz(o)->m_value; } /* Thunks */ -inline thunk_object::thunk_object(object * c): - object(object_kind::Thunk), m_closure(c), m_value(nullptr) { - /* Remark: the implementation relies on the fact that nullptr is not a valid lean object. */ - lean_assert(is_closure(c)); +inline thunk_object::thunk_object(object * c, bool is_value): + object(object_kind::Thunk) { + if (is_value) { + m_closure = nullptr; + m_value = c; + } else { + lean_assert(is_closure(c)); + m_closure = c; + m_value = nullptr; + } } /* Remark: `c`'s RC is not modified. Result object has RC == 1. */ inline object * mk_thunk(object * c) { - return new (malloc(sizeof(thunk_object))) thunk_object(c); // NOLINT + return new (malloc(sizeof(thunk_object))) thunk_object(c, false); // NOLINT +} + +/* Remark: `v`'s RC is not modified. Result object has RC == 1. */ +inline object * mk_thunk_from_value(object * v) { + return new (malloc(sizeof(thunk_object))) thunk_object(v, true); // NOLINT } object * apply_1(object * f, object * a1); diff --git a/src/runtime/serializer.cpp b/src/runtime/serializer.cpp index 08ef51cec2..978a65a6e8 100644 --- a/src/runtime/serializer.cpp +++ b/src/runtime/serializer.cpp @@ -100,6 +100,11 @@ void serializer::write_closure(object *) { // NOLINT throw exception("serializer for closures has not been implemented yet"); } +void serializer::write_thunk(object * o) { + object * r = thunk_get(o); + write_object(r); +} + void serializer::write_array(object * o) { lean_assert(is_array(o)); size_t sz = sarray_size(o); @@ -162,6 +167,7 @@ void serializer::write_object(object * o) { switch (k) { case object_kind::Constructor: write_constructor(o); break; case object_kind::Closure: write_closure(o); break; + case object_kind::Thunk: write_thunk(o); break; case object_kind::Array: write_array(o); break; case object_kind::ScalarArray: write_scalar_array(o); break; case object_kind::String: write_string_object(o); break; @@ -282,6 +288,12 @@ object * deserializer::read_closure() { throw exception("serializer for closures has not been implemented yet"); } +object * deserializer::read_thunk() { + object * v = read_object(); + inc(v); + return mk_thunk_from_value(v); +} + object * deserializer::read_array() { size_t sz = read_size_t(); object * r = alloc_array(sz, sz); @@ -334,6 +346,7 @@ object * deserializer::read_object() { switch (k) { case object_kind::Constructor: r = read_constructor(); break; case object_kind::Closure: r = read_closure(); break; + case object_kind::Thunk: r = read_thunk(); break; case object_kind::Array: r = read_array(); break; case object_kind::ScalarArray: r = read_scalar_array(); break; case object_kind::String: r = read_string_object(); break; diff --git a/src/runtime/serializer.h b/src/runtime/serializer.h index 544234161b..9fee31c456 100644 --- a/src/runtime/serializer.h +++ b/src/runtime/serializer.h @@ -23,6 +23,7 @@ class serializer { std::unordered_map, std::equal_to> m_obj_table; void write_constructor(object * o); void write_closure(object * o); + void write_thunk(object * o); void write_array(object * o); void write_scalar_array(object * o); void write_string_object(object * o); @@ -62,6 +63,7 @@ class deserializer { unsigned read_unsigned_ext(); object * read_constructor(); object * read_closure(); + object * read_thunk(); object * read_array(); object * read_scalar_array(); object * read_string_object(); diff --git a/src/tests/util/object.cpp b/src/tests/util/object.cpp index 32fd4b2933..88bf48629d 100644 --- a/src/tests/util/object.cpp +++ b/src/tests/util/object.cpp @@ -5,6 +5,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include +#include "runtime/serializer.h" #include "util/test.h" #include "util/object_ref.h" #include "util/init_module.h" @@ -85,6 +86,26 @@ static void tst4() { lean_assert(string_eq(r4, "hello world")); } +static void tst5() { + object_ref c(alloc_closure(r, 1, 0)); + object_ref t = mk_thunk_ref(c); + std::ostringstream out; + serializer s(out); + object_ref o(mk_string("bla bla")); + s.write_object(o.raw()); + s.write_object(t.raw()); + s.write_object(t.raw()); + std::istringstream in(out.str()); + deserializer d(in); + d.read_object(); + object * r1 = d.read_object(); + object * r2 = d.read_object(); + lean_assert(r1 == r2); + lean_assert(is_thunk(r1)); + object * str = thunk_get(r1); + lean_assert(strcmp(string_data(str), "hello world") == 0); +} + int main() { save_stack_info(); initialize_util_module(); @@ -92,6 +113,7 @@ int main() { tst2(); tst3(); tst4(); + tst5(); finalize_util_module(); return has_violations() ? 1 : 0; }