feat: maxsharing for constructors

This commit is contained in:
Leonardo de Moura 2020-02-25 15:43:10 -08:00
parent d099599b19
commit 46e8d193ca
2 changed files with 214 additions and 74 deletions

View file

@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <vector>
#include "runtime/object.h"
#include "runtime/hash.h"
@ -151,94 +152,120 @@ public:
template<typename state>
class max_sharing_fn {
state m_state;
state m_state;
std::vector<lean_object*> m_children;
std::vector<lean_object*> m_todo;
obj_res visit_closure(b_obj_arg a) {
// TODO(Leo)
lean_inc(a);
return a;
void clear_children() {
m_children.clear();
}
obj_res visit_array(b_obj_arg a) {
// TODO(Leo)
lean_inc(a);
return a;
}
obj_res visit_sarray(b_obj_arg a) {
// TODO(Leo)
lean_inc(a);
return a;
}
obj_res visit_string(b_obj_arg a) {
// TODO(Leo)
lean_inc(a);
return a;
}
obj_res visit_mpz(b_obj_arg a) {
// TODO(Leo)
lean_inc(a);
return a;
}
obj_res visit_thunk(b_obj_arg a) {
// TODO(Leo)
lean_inc(a);
return a;
}
obj_res visit_ctor(b_obj_arg a) {
// TODO(Leo)
lean_inc(a);
return a;
}
obj_res visit(obj_arg a) {
bool push_child(b_obj_arg a) {
if (lean_is_scalar(a)) {
return a;
m_children.push_back(a);
return true;
}
switch (lean_ptr_tag(a)) {
case LeanReserved:
lean_unreachable();
// We do not maximize sharing for the following kinds of objects
case LeanMPZ: case LeanThunk:
case LeanTask: case LeanRef:
case LeanExternal:
m_children.push_back(a);
return true;
default:
break;
}
// Check whether we have already maximized sharing for `a`
obj_res o = m_state.map_find(a);
if (o != lean_box(0)) {
obj_res r = lean_ctor_get(o, 0);
lean_inc(r);
lean_dec(o);
std::cout << "found cached:" << r << "\n";
return r;
}
obj_res r;
switch (lean_ptr_tag(a)) {
case LeanClosure: r = visit_closure(a); break;
case LeanArray: r = visit_array(a); break;
case LeanScalarArray: r = visit_sarray(a); break;
case LeanString: r = visit_string(a); break;
case LeanMPZ: r = visit_mpz(a); break;
case LeanThunk: r = visit_thunk(a); break;
case LeanTask: return a;
case LeanRef: return a;
case LeanExternal: return a;
case LeanReserved: lean_unreachable();
default: r = visit_ctor(a); break;
// The map still has a reference to `r`
m_children.push_back(r);
// std::cout << "cached maximized " << r << "\n";
return true;
}
obj_res opt_new_r = m_state.set_find(r);
m_todo.push_back(a);
return false;
}
void save(b_obj_arg a, obj_arg new_a) {
lean_assert(m_todo.size() > 0);
lean_assert(m_todo.back() == a);
m_todo.pop_back();
obj_res opt_new_r = m_state.set_find(new_a);
if (opt_new_r != lean_box(0)) {
lean_dec(r); // we already have a maximally shared term equivalent to `r`
r = lean_ctor_get(opt_new_r, 0);
lean_inc_n(r, 2);
lean_dec(new_a); // we already have a maximally shared term equivalent to `new_a`
new_a = lean_ctor_get(opt_new_r, 0);
lean_inc(new_a);
lean_dec(opt_new_r);
m_state.map_insert(a, r);
std::cout << "found shared:" << r << "\n";
return r;
lean_inc(a);
m_state.map_insert(a, new_a);
// std::cout << "already maximized " << new_a << "\n";
} else {
lean_inc(a);
lean_inc_n(new_a, 3);
m_state.set_insert(new_a); // `new_a` is a new maximally shared term
m_state.map_insert(a, new_a); // `new_a` is the maximally shared representation for `a`
m_state.map_insert(new_a, new_a); // `new_a` is the maximally shared representation for itself
// std::cout << "new maximized " << new_a << "\n";
}
}
lean_inc_n(r, 4);
m_state.set_insert(r); // r is a new maximally shared term
m_state.map_insert(a, r); // `r` is the maximally shared representation for `a`
m_state.map_insert(r, r); // `r` is the maximally shared representation of itself
std::cout << "new shared:" << r << " " << lean_maxsharing_hash(r) << "\n";
return r;
void visit_closure(b_obj_arg a) {
// TODO(Leo)
lean_inc(a);
save(a, a);
}
void visit_array(b_obj_arg a) {
// TODO(Leo)
lean_inc(a);
save(a, a);
}
void visit_sarray(b_obj_arg a) {
// TODO(Leo)
lean_inc(a);
save(a, a);
}
void visit_string(b_obj_arg a) {
lean_inc(a);
save(a, a);
}
void visit_ctor(b_obj_arg a) {
clear_children();
unsigned num_objs = lean_ctor_num_objs(a);
bool missing_child = false;
for (unsigned i = 0; i < num_objs; i++) {
if (!push_child(lean_ctor_get(a, i))) {
// std::cout << "missing_child " << a << " #" << i << " := " << lean_ctor_get(a, i) << std::endl;
missing_child = true;
lean_assert(m_todo.back() == lean_ctor_get(a, i));
}
}
if (missing_child)
return;
unsigned tag = lean_ptr_tag(a);
unsigned sz = lean_object_byte_size(a);
lean_object * new_a = lean_alloc_small_object(sz);
lean_set_st_header(new_a, tag, num_objs);
for (unsigned i = 0; i < num_objs; i++) {
lean_inc(m_children[i]);
lean_ctor_set(new_a, i, m_children[i]);
}
unsigned scalar_offset = sizeof(lean_object) + num_objs*sizeof(void*);
if (scalar_offset < sz) {
unsigned scalar_sz = sz - scalar_offset;
memcpy(reinterpret_cast<char*>(new_a) + scalar_offset, reinterpret_cast<char*>(a) + scalar_offset, scalar_sz);
}
save(a, new_a);
}
public:
@ -246,7 +273,34 @@ public:
}
obj_res operator()(obj_arg a) {
return m_state.pack(visit(a));
if (push_child(a)) {
return m_state.pack(a);
}
while (!m_todo.empty()) {
b_obj_arg curr = m_todo.back();
// std::cout << "visiting " << curr << " " << static_cast<unsigned>(lean_ptr_tag(curr)) << "\n";
switch (lean_ptr_tag(curr)) {
case LeanClosure: visit_closure(curr); break;
case LeanArray: visit_array(curr); break;
case LeanScalarArray: visit_sarray(curr); break;
case LeanString: visit_string(curr); break;
case LeanMPZ: lean_unreachable();
case LeanThunk: lean_unreachable();
case LeanTask: lean_unreachable();
case LeanRef: lean_unreachable();
case LeanExternal: lean_unreachable();
case LeanReserved: lean_unreachable();
default: visit_ctor(curr); break;
}
}
obj_res o = m_state.map_find(a);
lean_assert(o != lean_box(0));
obj_res r = lean_ctor_get(o, 0);
lean_inc(r);
lean_dec(o);
lean_dec(a);
return m_state.pack(r);
}
};

View file

@ -0,0 +1,86 @@
def check (b : Bool) : IO Unit :=
unless b $ throw $ IO.userError "check failed"
unsafe def tst1 : IO Unit := do
let x := [1];
let y := [0].map (fun x => x + 1);
let s := MaxSharing.State.empty;
check $ ptrAddrUnsafe x != ptrAddrUnsafe y;
let (x, s) := s.maxSharing x;
let (y, s) := s.maxSharing y;
check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
let (z, s) := s.maxSharing [2];
let (x, s) := s.maxSharing x;
check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
check $ ptrAddrUnsafe x != ptrAddrUnsafe z;
IO.println x;
IO.println y;
IO.println z
#eval tst1
unsafe def tst2 : IO Unit := do
let x := [1, 2];
let y := [0, 1].map (fun x => x + 1);
check $ ptrAddrUnsafe x != ptrAddrUnsafe y;
let s := MaxSharing.State.empty;
let (x, s) := s.maxSharing x;
let (y, s) := s.maxSharing y;
check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
let (z, s) := s.maxSharing [2];
let (x, s) := s.maxSharing x;
check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
check $ ptrAddrUnsafe x != ptrAddrUnsafe z;
IO.println x;
IO.println y;
IO.println z
#eval tst2
structure Foo :=
(x : Nat)
(y : Bool)
(z : Bool)
@[noinline] def mkFoo1 (x : Nat) (z : Bool) : Foo := { x := x, y := true, z := z }
@[noinline] def mkFoo2 (x : Nat) (z : Bool) : Foo := { x := x, y := true, z := z }
unsafe def tst3 : IO Unit := do
let o1 := mkFoo1 10 true;
let o2 := mkFoo2 10 true;
let o3 := mkFoo2 10 false;
check $ ptrAddrUnsafe o1 != ptrAddrUnsafe o2;
check $ ptrAddrUnsafe o1 != ptrAddrUnsafe o3;
let s := MaxSharing.State.empty;
let (o1, s) := s.maxSharing o1;
let (o2, s) := s.maxSharing o2;
let (o3, s) := s.maxSharing o3;
check $ o1.x == 10;
check $ o1.y == true;
check $ o1.z == true;
check $ o3.z == false;
check $ ptrAddrUnsafe o1 == ptrAddrUnsafe o2;
check $ ptrAddrUnsafe o1 != ptrAddrUnsafe o3;
IO.println o1.x;
pure ()
#eval tst3
unsafe def tst4 : IO Unit := do
let x := ["hello"];
let y := ["ello"].map (fun x => "h" ++ x);
check $ ptrAddrUnsafe x != ptrAddrUnsafe y;
let s := MaxSharing.State.empty;
let (x, s) := s.maxSharing x;
let (y, s) := s.maxSharing y;
-- check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
let (z, s) := s.maxSharing ["world"];
let (x, s) := s.maxSharing x;
-- check $ ptrAddrUnsafe x == ptrAddrUnsafe y;
check $ ptrAddrUnsafe x != ptrAddrUnsafe z;
IO.println x;
IO.println y;
IO.println z
#eval tst3