diff --git a/src/runtime/maxsharing.cpp b/src/runtime/maxsharing.cpp
index 809e381110..546f8b3d7e 100644
--- a/src/runtime/maxsharing.cpp
+++ b/src/runtime/maxsharing.cpp
@@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
 
 Author: Leonardo de Moura
 */
+#include <vector>
 #include "runtime/object.h"
 #include "runtime/hash.h"
 
@@ -151,94 +152,130 @@ public:
 
 template
 class max_sharing_fn {
-    state m_state;
+    state       m_state;
+    std::vector m_children;
+    std::vector m_todo;
 
-    obj_res visit_closure(b_obj_arg a) {
-        // TODO(Leo)
-        lean_inc(a);
-        return a;
+    void clear_children() {
+        m_children.clear();
     }
 
-    obj_res visit_array(b_obj_arg a) {
-        // TODO(Leo)
-        lean_inc(a);
-        return a;
-    }
-
-    obj_res visit_sarray(b_obj_arg a) {
-        // TODO(Leo)
-        lean_inc(a);
-        return a;
-    }
-
-    obj_res visit_string(b_obj_arg a) {
-        // TODO(Leo)
-        lean_inc(a);
-        return a;
-    }
-
-    obj_res visit_mpz(b_obj_arg a) {
-        // TODO(Leo)
-        lean_inc(a);
-        return a;
-    }
-
-    obj_res visit_thunk(b_obj_arg a) {
-        // TODO(Leo)
-        lean_inc(a);
-        return a;
-    }
-
-    obj_res visit_ctor(b_obj_arg a) {
-        // TODO(Leo)
-        lean_inc(a);
-        return a;
-    }
-
-    obj_res visit(obj_arg a) {
+    // Resolve `a` without visiting it when possible. Scalars, objects of kinds we do
+    // not maximize (MPZ, thunk, task, ref, external), and objects already present in
+    // the map are appended to `m_children` (the cached representative in the last
+    // case) and `true` is returned. Otherwise `a` is pushed onto `m_todo` and
+    // `false` is returned.
+    bool push_child(b_obj_arg a) {
         if (lean_is_scalar(a)) {
-            return a;
+            m_children.push_back(a);
+            return true;
         }
+        switch (lean_ptr_tag(a)) {
+        case LeanReserved:
+            lean_unreachable();
+        // We do not maximize sharing for the following kinds of objects
+        case LeanMPZ: case LeanThunk:
+        case LeanTask: case LeanRef:
+        case LeanExternal:
+            m_children.push_back(a);
+            return true;
+        default:
+            break;
+        }
+
+        // Check whether we have already maximized sharing for `a`
         obj_res o = m_state.map_find(a);
         if (o != lean_box(0)) {
             obj_res r = lean_ctor_get(o, 0);
-            lean_inc(r);
             lean_dec(o);
-            std::cout << "found cached:" << r << "\n";
-            return r;
-        }
-        obj_res r;
-        switch (lean_ptr_tag(a)) {
-        case LeanClosure: r = visit_closure(a); break;
-        case LeanArray: r = visit_array(a); break;
-        case LeanScalarArray: r = visit_sarray(a); break;
-        case LeanString: r = visit_string(a); break;
-        case LeanMPZ: r = visit_mpz(a); break;
-        case LeanThunk: r = visit_thunk(a); break;
-        case LeanTask: return a;
-        case LeanRef: return a;
-        case LeanExternal: return a;
-        case LeanReserved: lean_unreachable();
-        default: r = visit_ctor(a); break;
+            // The map still has a reference to `r`
+            m_children.push_back(r);
+            // std::cout << "cached maximized " << r << "\n";
+            return true;
         }
 
-        obj_res opt_new_r = m_state.set_find(r);
+        m_todo.push_back(a);
+        return false;
+    }
+
+    // Pop `a` from `m_todo` and record its maximally shared representation: the
+    // equivalent term already in the set if there is one, `new_a` otherwise.
+    void save(b_obj_arg a, obj_arg new_a) {
+        lean_assert(m_todo.size() > 0);
+        lean_assert(m_todo.back() == a);
+        m_todo.pop_back();
+        obj_res opt_new_r = m_state.set_find(new_a);
         if (opt_new_r != lean_box(0)) {
-            lean_dec(r); // we already have a maximally shared term equivalent to `r`
-            r = lean_ctor_get(opt_new_r, 0);
-            lean_inc_n(r, 2);
+            lean_dec(new_a); // we already have a maximally shared term equivalent to `new_a`
+            new_a = lean_ctor_get(opt_new_r, 0);
+            lean_inc(new_a);
             lean_dec(opt_new_r);
-            m_state.map_insert(a, r);
-            std::cout << "found shared:" << r << "\n";
-            return r;
+            lean_inc(a);
+            m_state.map_insert(a, new_a);
+            // std::cout << "already maximized " << new_a << "\n";
+        } else {
+            lean_inc(a);
+            lean_inc_n(new_a, 3);
+            m_state.set_insert(new_a); // `new_a` is a new maximally shared term
+            m_state.map_insert(a, new_a); // `new_a` is the maximally shared representation for `a`
+            m_state.map_insert(new_a, new_a); // `new_a` is the maximally shared representation for itself
+            // std::cout << "new maximized " << new_a << "\n";
         }
+    }
 
-        lean_inc_n(r, 4);
-        m_state.set_insert(r); // r is a new maximally shared term
-        m_state.map_insert(a, r); // `r` is the maximally shared representation for `a`
-        m_state.map_insert(r, r); // `r` is the maximally shared representation of itself
-        std::cout << "new shared:" << r << " " << lean_maxsharing_hash(r) << "\n";
-        return r;
+    void visit_closure(b_obj_arg a) {
+        // TODO(Leo)
+        lean_inc(a);
+        save(a, a);
+    }
+
+    void visit_array(b_obj_arg a) {
+        // TODO(Leo)
+        lean_inc(a);
+        save(a, a);
+    }
+
+    void visit_sarray(b_obj_arg a) {
+        // TODO(Leo)
+        lean_inc(a);
+        save(a, a);
+    }
+
+    void visit_string(b_obj_arg a) {
+        lean_inc(a);
+        save(a, a);
+    }
+
+    // Rebuild constructor object `a` from the maximally shared versions of its
+    // fields. If a field is still unresolved it has been pushed onto `m_todo`, and
+    // `a` stays on the stack to be revisited once its fields are done.
+    void visit_ctor(b_obj_arg a) {
+        clear_children();
+        unsigned num_objs = lean_ctor_num_objs(a);
+        bool missing_child = false;
+        for (unsigned i = 0; i < num_objs; i++) {
+            if (!push_child(lean_ctor_get(a, i))) {
+                // std::cout << "missing_child " << a << " #" << i << " := " << lean_ctor_get(a, i) << std::endl;
+                missing_child = true;
+                lean_assert(m_todo.back() == lean_ctor_get(a, i));
+            }
+        }
+        if (missing_child)
+            return;
+        unsigned tag = lean_ptr_tag(a);
+        unsigned sz = lean_object_byte_size(a);
+        lean_object * new_a = lean_alloc_small_object(sz);
+        lean_set_st_header(new_a, tag, num_objs);
+        for (unsigned i = 0; i < num_objs; i++) {
+            lean_inc(m_children[i]);
+            lean_ctor_set(new_a, i, m_children[i]);
+        }
+        unsigned scalar_offset = sizeof(lean_object) + num_objs*sizeof(void*);
+        if (scalar_offset < sz) {
+            unsigned scalar_sz = sz - scalar_offset;
+            memcpy(reinterpret_cast(new_a) + scalar_offset, reinterpret_cast(a) + scalar_offset, scalar_sz);
+        }
+        save(a, new_a);
     }
 
 public:
@@ -246,7 +283,44 @@ public:
     }
 
     obj_res operator()(obj_arg a) {
-        return m_state.pack(visit(a));
+        if (push_child(a)) {
+            // `push_child` left on top of `m_children` either `a` itself or, when
+            // `a` was already in the map, its maximally shared representative.
+            // Return that object instead of unconditionally returning `a`.
+            lean_object * r = m_children.back();
+            if (r != a) {
+                lean_inc(r);
+                lean_dec(a);
+            }
+            return m_state.pack(r);
+        }
+        // Process pending objects until the explicit stack is empty; each visit
+        // either completes the top object or pushes its unresolved children.
+        while (!m_todo.empty()) {
+            b_obj_arg curr = m_todo.back();
+            // std::cout << "visiting " << curr << " " << static_cast(lean_ptr_tag(curr)) << "\n";
+            switch (lean_ptr_tag(curr)) {
+            case LeanClosure: visit_closure(curr); break;
+            case LeanArray: visit_array(curr); break;
+            case LeanScalarArray: visit_sarray(curr); break;
+            case LeanString: visit_string(curr); break;
+            case LeanMPZ: lean_unreachable();
+            case LeanThunk: lean_unreachable();
+            case LeanTask: lean_unreachable();
+            case LeanRef: lean_unreachable();
+            case LeanExternal: lean_unreachable();
+            case LeanReserved: lean_unreachable();
+            default: visit_ctor(curr); break;
+            }
+        }
+
+        obj_res o = m_state.map_find(a);
+        lean_assert(o != lean_box(0));
+        obj_res r = lean_ctor_get(o, 0);
+        lean_inc(r);
+        lean_dec(o);
+        lean_dec(a);
+        return m_state.pack(r);
     }
 };
 
diff --git a/tests/lean/run/maxsharing.lean b/tests/lean/run/maxsharing.lean
new file mode 100644
index 0000000000..ca888e59ab
--- /dev/null
+++ b/tests/lean/run/maxsharing.lean
@@ -0,0 +1,86 @@
+def check (b : Bool) : IO Unit :=
+unless b $ throw $ IO.userError "check failed"
+
+unsafe def tst1 : IO Unit := do
+let x := [1];
+let y := [0].map (fun x => x + 1);
+let s := MaxSharing.State.empty;
+check $ ptrAddrUnsafe x != ptrAddrUnsafe y;
+let (x, s) := s.maxSharing x;
+let (y, s) := s.maxSharing y;
+check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
+let (z, s) := s.maxSharing [2];
+let (x, s) := s.maxSharing x;
+check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
+check $ ptrAddrUnsafe x != ptrAddrUnsafe z;
+IO.println x;
+IO.println y;
+IO.println z
+
+#eval tst1
+
+unsafe def tst2 : IO Unit := do
+let x := [1, 2];
+let y := [0, 1].map (fun x => x + 1);
+check $ ptrAddrUnsafe x != ptrAddrUnsafe y;
+let s := MaxSharing.State.empty;
+let (x, s) := s.maxSharing x;
+let (y, s) := s.maxSharing y;
+check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
+let (z, s) := s.maxSharing [2];
+let (x, s) := s.maxSharing x;
+check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
+check $ ptrAddrUnsafe x != ptrAddrUnsafe z;
+IO.println x;
+IO.println y;
+IO.println z
+
+#eval tst2
+
+structure Foo :=
+(x : Nat)
+(y : Bool)
+(z : Bool)
+
+@[noinline] def mkFoo1 (x : Nat) (z : Bool) : Foo := { x := x, y := true, z := z }
+@[noinline] def mkFoo2 (x : Nat) (z : Bool) : Foo := { x := x, y := true, z := z }
+
+unsafe def tst3 : IO Unit := do
+let o1 := mkFoo1 10 true;
+let o2 := mkFoo2 10 true;
+let o3 := mkFoo2 10 false;
+check $ ptrAddrUnsafe o1 != ptrAddrUnsafe o2;
+check $ ptrAddrUnsafe o1 != ptrAddrUnsafe o3;
+let s := MaxSharing.State.empty;
+let (o1, s) := s.maxSharing o1;
+let (o2, s) := s.maxSharing o2;
+let (o3, s) := s.maxSharing o3;
+check $ o1.x == 10;
+check $ o1.y == true;
+check $ o1.z == true;
+check $ o3.z == false;
+check $ ptrAddrUnsafe o1 == ptrAddrUnsafe o2;
+check $ ptrAddrUnsafe o1 != ptrAddrUnsafe o3;
+IO.println o1.x;
+pure ()
+
+#eval tst3
+
+unsafe def tst4 : IO Unit := do
+let x := ["hello"];
+let y := ["ello"].map (fun x => "h" ++ x);
+check $ ptrAddrUnsafe x != ptrAddrUnsafe y;
+let s := MaxSharing.State.empty;
+let (x, s) := s.maxSharing x;
+let (y, s) := s.maxSharing y;
+-- check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
+let (z, s) := s.maxSharing ["world"];
+let (x, s) := s.maxSharing x;
+-- check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
+check $ ptrAddrUnsafe x != ptrAddrUnsafe z;
+IO.println x;
+IO.println y;
+IO.println z
+
+
+#eval tst4