feat: maxsharing for constructors
This commit is contained in:
parent
d099599b19
commit
46e8d193ca
2 changed files with 214 additions and 74 deletions
|
|
@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#include <vector>
|
||||
#include "runtime/object.h"
|
||||
#include "runtime/hash.h"
|
||||
|
||||
|
|
@ -151,94 +152,120 @@ public:
|
|||
|
||||
template<typename state>
|
||||
class max_sharing_fn {
|
||||
state m_state;
|
||||
state m_state;
|
||||
std::vector<lean_object*> m_children;
|
||||
std::vector<lean_object*> m_todo;
|
||||
|
||||
obj_res visit_closure(b_obj_arg a) {
|
||||
// TODO(Leo)
|
||||
lean_inc(a);
|
||||
return a;
|
||||
// Empties the m_children scratch vector. visit_ctor calls this before it starts
// collecting the children of a constructor object via push_child.
void clear_children() {
|
||||
m_children.clear();
|
||||
}
|
||||
|
||||
obj_res visit_array(b_obj_arg a) {
|
||||
// TODO(Leo)
|
||||
lean_inc(a);
|
||||
return a;
|
||||
}
|
||||
|
||||
obj_res visit_sarray(b_obj_arg a) {
|
||||
// TODO(Leo)
|
||||
lean_inc(a);
|
||||
return a;
|
||||
}
|
||||
|
||||
obj_res visit_string(b_obj_arg a) {
|
||||
// TODO(Leo)
|
||||
lean_inc(a);
|
||||
return a;
|
||||
}
|
||||
|
||||
obj_res visit_mpz(b_obj_arg a) {
|
||||
// TODO(Leo)
|
||||
lean_inc(a);
|
||||
return a;
|
||||
}
|
||||
|
||||
obj_res visit_thunk(b_obj_arg a) {
|
||||
// TODO(Leo)
|
||||
lean_inc(a);
|
||||
return a;
|
||||
}
|
||||
|
||||
obj_res visit_ctor(b_obj_arg a) {
|
||||
// TODO(Leo)
|
||||
lean_inc(a);
|
||||
return a;
|
||||
}
|
||||
|
||||
obj_res visit(obj_arg a) {
|
||||
bool push_child(b_obj_arg a) {
|
||||
if (lean_is_scalar(a)) {
|
||||
return a;
|
||||
m_children.push_back(a);
|
||||
return true;
|
||||
}
|
||||
switch (lean_ptr_tag(a)) {
|
||||
case LeanReserved:
|
||||
lean_unreachable();
|
||||
// We do not maximize sharing for the following kinds of objects
|
||||
case LeanMPZ: case LeanThunk:
|
||||
case LeanTask: case LeanRef:
|
||||
case LeanExternal:
|
||||
m_children.push_back(a);
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
// Check whether we have already maximized sharing for `a`
|
||||
obj_res o = m_state.map_find(a);
|
||||
if (o != lean_box(0)) {
|
||||
obj_res r = lean_ctor_get(o, 0);
|
||||
lean_inc(r);
|
||||
lean_dec(o);
|
||||
std::cout << "found cached:" << r << "\n";
|
||||
return r;
|
||||
}
|
||||
obj_res r;
|
||||
switch (lean_ptr_tag(a)) {
|
||||
case LeanClosure: r = visit_closure(a); break;
|
||||
case LeanArray: r = visit_array(a); break;
|
||||
case LeanScalarArray: r = visit_sarray(a); break;
|
||||
case LeanString: r = visit_string(a); break;
|
||||
case LeanMPZ: r = visit_mpz(a); break;
|
||||
case LeanThunk: r = visit_thunk(a); break;
|
||||
case LeanTask: return a;
|
||||
case LeanRef: return a;
|
||||
case LeanExternal: return a;
|
||||
case LeanReserved: lean_unreachable();
|
||||
default: r = visit_ctor(a); break;
|
||||
// The map still has a reference to `r`
|
||||
m_children.push_back(r);
|
||||
// std::cout << "cached maximized " << r << "\n";
|
||||
return true;
|
||||
}
|
||||
|
||||
obj_res opt_new_r = m_state.set_find(r);
|
||||
m_todo.push_back(a);
|
||||
return false;
|
||||
}
|
||||
|
||||
void save(b_obj_arg a, obj_arg new_a) {
|
||||
lean_assert(m_todo.size() > 0);
|
||||
lean_assert(m_todo.back() == a);
|
||||
m_todo.pop_back();
|
||||
obj_res opt_new_r = m_state.set_find(new_a);
|
||||
if (opt_new_r != lean_box(0)) {
|
||||
lean_dec(r); // we already have a maximally shared term equivalent to `r`
|
||||
r = lean_ctor_get(opt_new_r, 0);
|
||||
lean_inc_n(r, 2);
|
||||
lean_dec(new_a); // we already have a maximally shared term equivalent to `new_a`
|
||||
new_a = lean_ctor_get(opt_new_r, 0);
|
||||
lean_inc(new_a);
|
||||
lean_dec(opt_new_r);
|
||||
m_state.map_insert(a, r);
|
||||
std::cout << "found shared:" << r << "\n";
|
||||
return r;
|
||||
lean_inc(a);
|
||||
m_state.map_insert(a, new_a);
|
||||
// std::cout << "already maximized " << new_a << "\n";
|
||||
} else {
|
||||
lean_inc(a);
|
||||
lean_inc_n(new_a, 3);
|
||||
m_state.set_insert(new_a); // `new_a` is a new maximally shared term
|
||||
m_state.map_insert(a, new_a); // `new_a` is the maximally shared representation for `a`
|
||||
m_state.map_insert(new_a, new_a); // `new_a` is the maximally shared representation for itself
|
||||
// std::cout << "new maximized " << new_a << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
lean_inc_n(r, 4);
|
||||
m_state.set_insert(r); // r is a new maximally shared term
|
||||
m_state.map_insert(a, r); // `r` is the maximally shared representation for `a`
|
||||
m_state.map_insert(r, r); // `r` is the maximally shared representation of itself
|
||||
std::cout << "new shared:" << r << " " << lean_maxsharing_hash(r) << "\n";
|
||||
return r;
|
||||
// Worklist handler for closure objects. Rebuilding the closure from shared
// children is not implemented yet (see the TODO): `a` itself is handed to
// save() as the candidate shared representation of `a`.
void visit_closure(b_obj_arg a) {
|
||||
// TODO(Leo)
|
||||
// `a` is declared b_obj_arg here while save() declares its second parameter
// obj_arg; presumably this increment supplies the reference save() takes over.
lean_inc(a);
|
||||
save(a, a);
|
||||
}
|
||||
|
||||
// Worklist handler for array objects. Sharing of the array elements is not
// implemented yet (see the TODO): `a` itself is handed to save() as the
// candidate shared representation of `a`.
void visit_array(b_obj_arg a) {
|
||||
// TODO(Leo)
|
||||
// `a` is declared b_obj_arg here while save() declares its second parameter
// obj_arg; presumably this increment supplies the reference save() takes over.
lean_inc(a);
|
||||
save(a, a);
|
||||
}
|
||||
|
||||
// Worklist handler for scalar-array objects. Still a placeholder (see the
// TODO): `a` itself is handed to save() as the candidate shared representation
// of `a`, without building a copy.
void visit_sarray(b_obj_arg a) {
|
||||
// TODO(Leo)
|
||||
// `a` is declared b_obj_arg here while save() declares its second parameter
// obj_arg; presumably this increment supplies the reference save() takes over.
lean_inc(a);
|
||||
save(a, a);
|
||||
}
|
||||
|
||||
// Worklist handler for string objects: `a` itself is handed to save() as the
// candidate shared representation of `a`; no copy of the string is built.
void visit_string(b_obj_arg a) {
|
||||
// `a` is declared b_obj_arg here while save() declares its second parameter
// obj_arg; presumably this increment supplies the reference save() takes over.
lean_inc(a);
|
||||
save(a, a);
|
||||
}
|
||||
|
||||
// Worklist handler for constructor objects.
//
// Collects the object fields of `a` through push_child. If push_child returns
// false for any field, this function returns without building anything and
// without calling save(), so `a` is not popped from m_todo here; presumably it
// is visited again once the pending children have been processed.
// Otherwise a copy of `a` is allocated whose object fields are the entries of
// m_children and whose remaining bytes are copied from `a`, and the copy is
// passed to save(a, new_a).
void visit_ctor(b_obj_arg a) {
|
||||
// Start from an empty m_children so that m_children[i] lines up with field i below.
clear_children();
|
||||
unsigned num_objs = lean_ctor_num_objs(a);
|
||||
bool missing_child = false;
|
||||
// Keep iterating after the first miss: push_child is still called for every field.
for (unsigned i = 0; i < num_objs; i++) {
|
||||
if (!push_child(lean_ctor_get(a, i))) {
|
||||
// std::cout << "missing_child " << a << " #" << i << " := " << lean_ctor_get(a, i) << std::endl;
|
||||
missing_child = true;
|
||||
// A false result is expected to have left the child on top of m_todo.
lean_assert(m_todo.back() == lean_ctor_get(a, i));
|
||||
}
|
||||
}
|
||||
if (missing_child)
|
||||
return;
|
||||
// All fields are available in m_children: build the replacement object with
// the same tag, byte size and number of object fields as `a`.
unsigned tag = lean_ptr_tag(a);
|
||||
unsigned sz = lean_object_byte_size(a);
|
||||
lean_object * new_a = lean_alloc_small_object(sz);
|
||||
lean_set_st_header(new_a, tag, num_objs);
|
||||
for (unsigned i = 0; i < num_objs; i++) {
|
||||
// The entry stays in m_children as well, so the copy gets its own reference.
lean_inc(m_children[i]);
|
||||
lean_ctor_set(new_a, i, m_children[i]);
|
||||
}
|
||||
// Everything past the header and the num_objs pointer-sized fields is copied
// verbatim from `a` (presumably the constructor's scalar data).
unsigned scalar_offset = sizeof(lean_object) + num_objs*sizeof(void*);
|
||||
if (scalar_offset < sz) {
|
||||
unsigned scalar_sz = sz - scalar_offset;
|
||||
memcpy(reinterpret_cast<char*>(new_a) + scalar_offset, reinterpret_cast<char*>(a) + scalar_offset, scalar_sz);
|
||||
}
|
||||
// Hand the freshly built copy to save() as the candidate representation of `a`.
save(a, new_a);
|
||||
}
|
||||
|
||||
public:
|
||||
|
|
@ -246,7 +273,34 @@ public:
|
|||
}
|
||||
|
||||
obj_res operator()(obj_arg a) {
|
||||
return m_state.pack(visit(a));
|
||||
if (push_child(a)) {
|
||||
return m_state.pack(a);
|
||||
}
|
||||
while (!m_todo.empty()) {
|
||||
b_obj_arg curr = m_todo.back();
|
||||
// std::cout << "visiting " << curr << " " << static_cast<unsigned>(lean_ptr_tag(curr)) << "\n";
|
||||
switch (lean_ptr_tag(curr)) {
|
||||
case LeanClosure: visit_closure(curr); break;
|
||||
case LeanArray: visit_array(curr); break;
|
||||
case LeanScalarArray: visit_sarray(curr); break;
|
||||
case LeanString: visit_string(curr); break;
|
||||
case LeanMPZ: lean_unreachable();
|
||||
case LeanThunk: lean_unreachable();
|
||||
case LeanTask: lean_unreachable();
|
||||
case LeanRef: lean_unreachable();
|
||||
case LeanExternal: lean_unreachable();
|
||||
case LeanReserved: lean_unreachable();
|
||||
default: visit_ctor(curr); break;
|
||||
}
|
||||
}
|
||||
|
||||
obj_res o = m_state.map_find(a);
|
||||
lean_assert(o != lean_box(0));
|
||||
obj_res r = lean_ctor_get(o, 0);
|
||||
lean_inc(r);
|
||||
lean_dec(o);
|
||||
lean_dec(a);
|
||||
return m_state.pack(r);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
86
tests/lean/run/maxsharing.lean
Normal file
86
tests/lean/run/maxsharing.lean
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
-- Test assertion helper: throws `IO.userError "check failed"` unless `b` holds.
def check (b : Bool) : IO Unit :=
|
||||
unless b $ throw $ IO.userError "check failed"
|
||||
|
||||
-- Single-element lists: `x` and `y` are built differently and are checked to
-- have different addresses; after both go through `maxSharing` with the same
-- state they must have the same address.
unsafe def tst1 : IO Unit := do
|
||||
let x := [1];
|
||||
let y := [0].map (fun x => x + 1);
|
||||
let s := MaxSharing.State.empty;
|
||||
check $ ptrAddrUnsafe x != ptrAddrUnsafe y;
|
||||
let (x, s) := s.maxSharing x;
|
||||
let (y, s) := s.maxSharing y;
|
||||
check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
|
||||
-- A different list must not collapse onto `x`, and running `maxSharing` again
-- on the already processed `x` must still give the same address as `y`.
let (z, s) := s.maxSharing [2];
|
||||
let (x, s) := s.maxSharing x;
|
||||
check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
|
||||
check $ ptrAddrUnsafe x != ptrAddrUnsafe z;
|
||||
IO.println x;
|
||||
IO.println y;
|
||||
IO.println z
|
||||
|
||||
-- Run the test when this file is elaborated.
#eval tst1
|
||||
|
||||
-- Same scenario as `tst1`, but with two-element lists, so the value being
-- shared has a nested tail.
unsafe def tst2 : IO Unit := do
|
||||
let x := [1, 2];
|
||||
let y := [0, 1].map (fun x => x + 1);
|
||||
check $ ptrAddrUnsafe x != ptrAddrUnsafe y;
|
||||
let s := MaxSharing.State.empty;
|
||||
let (x, s) := s.maxSharing x;
|
||||
let (y, s) := s.maxSharing y;
|
||||
check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
|
||||
-- `[2]` differs from `[1, 2]`, so `z` must end up at a different address than `x`.
let (z, s) := s.maxSharing [2];
|
||||
let (x, s) := s.maxSharing x;
|
||||
check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
|
||||
check $ ptrAddrUnsafe x != ptrAddrUnsafe z;
|
||||
IO.println x;
|
||||
IO.println y;
|
||||
IO.println z
|
||||
|
||||
-- Run the test when this file is elaborated.
#eval tst2
|
||||
|
||||
-- Test structure with one `Nat` field and two `Bool` fields; `tst3` checks that
-- all three field values survive `maxSharing`.
-- NOTE(review): presumably chosen so the object carries non-pointer (scalar) data
-- next to an object field — confirm against the runtime's structure layout.
structure Foo :=
|
||||
(x : Nat)
|
||||
(y : Bool)
|
||||
(z : Bool)
|
||||
|
||||
-- Two identical `Foo` builders (`y` is always `true`). `tst3` checks that the
-- values they return start out at different addresses; the duplication and
-- `@[noinline]` presumably exist to guarantee separate allocations.
@[noinline] def mkFoo1 (x : Nat) (z : Bool) : Foo := { x := x, y := true, z := z }
|
||||
@[noinline] def mkFoo2 (x : Nat) (z : Bool) : Foo := { x := x, y := true, z := z }
|
||||
|
||||
-- Structure test: `o1` and `o2` hold equal field values, `o3` differs only in `z`.
-- After `maxSharing`, `o1` and `o2` must have the same address, `o3` must not,
-- and the field values must be unchanged.
unsafe def tst3 : IO Unit := do
|
||||
let o1 := mkFoo1 10 true;
|
||||
let o2 := mkFoo2 10 true;
|
||||
let o3 := mkFoo2 10 false;
|
||||
check $ ptrAddrUnsafe o1 != ptrAddrUnsafe o2;
|
||||
check $ ptrAddrUnsafe o1 != ptrAddrUnsafe o3;
|
||||
let s := MaxSharing.State.empty;
|
||||
let (o1, s) := s.maxSharing o1;
|
||||
let (o2, s) := s.maxSharing o2;
|
||||
let (o3, s) := s.maxSharing o3;
|
||||
-- Field values must be preserved by the sharing pass.
check $ o1.x == 10;
|
||||
check $ o1.y == true;
|
||||
check $ o1.z == true;
|
||||
check $ o3.z == false;
|
||||
-- Equal values share one address; the value with a different `z` does not.
check $ ptrAddrUnsafe o1 == ptrAddrUnsafe o2;
|
||||
check $ ptrAddrUnsafe o1 != ptrAddrUnsafe o3;
|
||||
IO.println o1.x;
|
||||
pure ()
|
||||
|
||||
-- Run the test when this file is elaborated.
#eval tst3
|
||||
|
||||
-- String-list variant of `tst1`. The two address-equality checks are commented
-- out, so only the inequality checks are enforced here.
-- NOTE(review): presumably disabled because strings are not yet merged by the
-- sharing pass — confirm before re-enabling them.
unsafe def tst4 : IO Unit := do
|
||||
let x := ["hello"];
|
||||
let y := ["ello"].map (fun x => "h" ++ x);
|
||||
check $ ptrAddrUnsafe x != ptrAddrUnsafe y;
|
||||
let s := MaxSharing.State.empty;
|
||||
let (x, s) := s.maxSharing x;
|
||||
let (y, s) := s.maxSharing y;
|
||||
-- check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
|
||||
let (z, s) := s.maxSharing ["world"];
|
||||
let (x, s) := s.maxSharing x;
|
||||
-- check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
|
||||
check $ ptrAddrUnsafe x != ptrAddrUnsafe z;
|
||||
IO.println x;
|
||||
IO.println y;
|
||||
IO.println z
|
||||
|
||||
|
||||
-- Run the string-list test defined directly above; every other `tstN` in this
-- file is followed by its own `#eval`.
#eval tst4
|
||||
Loading…
Add table
Reference in a new issue