feat(library/init/data, runtime): remove parray support from runtime, and implement them in Lean using Scala/Clojure Radix trees

The Scala/Clojure approach for persistent arrays works great with our
`reset/reuse`. We seem to be much more efficient than their
implementations because of `reset/reuse`. The new approach also seems
better than the old one implemented in the runtime, and has a few
advantages:
1- The reroot procedure used in the old approach required
synchronization for multi-threaded code, or we would need to perform
deep copies when sending `parray` objects between threads.
2- We don't need any runtime extension for the new approach.
3- The old approach used "trail lists" for undoing array updates.
This works well for backtracking search use cases, but it is bad
in use cases where we are simultaneously updating the persistent
arrays that have shared nodes.
This commit is contained in:
Leonardo de Moura 2019-06-02 09:18:19 -07:00
parent 9d00a8d262
commit 30a6a2ade8
9 changed files with 211 additions and 641 deletions

View file

@ -0,0 +1,176 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import init.data.array
universes u v w
/- Tree node of a persistent array: an interior `node` holds an array of child nodes,
a `leaf` holds an array of elements. -/
inductive PersistentArrayNode (α : Type u)
| node (cs : Array PersistentArrayNode) : PersistentArrayNode
| leaf (vs : Array α) : PersistentArrayNode
/- The default node is a leaf built from `Array.empty`. -/
instance PersistentArrayNode.inhabited {α : Type u} : Inhabited (PersistentArrayNode α) :=
⟨PersistentArrayNode.leaf Array.empty⟩
/- Number of index bits handled per tree level; also the default root `shift`. -/
abbrev PersistentArray.initShift : USize := 5
/- Branching factor: `2 ^ initShift`, i.e. 32. -/
abbrev PersistentArray.branching : USize := USize.ofNat (2 ^ PersistentArray.initShift.toNat)
structure PersistentArray (α : Type u) :=
/- Recall that we run out of memory if we have more than `usizeSz/8` elements.
So, we can stop adding elements to `root` once `size >= usizeSz/8` (see `tooBig` in `push`), and
keep growing the `tail`. This modification allows us to use `USize` instead
of `Nat` when traversing `root`. -/
-- tree holding the elements whose index is below `tailOff`
(root : PersistentArrayNode α := PersistentArrayNode.node (Array.mkEmpty PersistentArray.branching.toNat))
-- elements whose index is at or above `tailOff`
(tail : Array α := Array.mkEmpty PersistentArray.branching.toNat)
-- total number of elements (`push` increments it by one)
(size : Nat := 0)
-- shift applied at the root when an index is split during tree traversal
(shift : USize := PersistentArray.initShift)
-- index of the first element stored in `tail`
(tailOff : Nat := 0)
/- Short alias for `PersistentArray`. -/
abbrev PArray (α : Type u) := PersistentArray α
namespace PersistentArray
/- TODO: use proofs for showing that array accesses are not out of bounds.
We can do it after we reimplement the tactic framework. -/
variables {α : Type u} {β : Type v}
open PersistentArrayNode
/- The default persistent array is `{}`, i.e. every field takes its declared default. -/
instance : Inhabited (PersistentArray α) := ⟨{}⟩
/- Empty array created with `Array.mkEmpty`, passing `branching` as the requested capacity. -/
def mkEmptyArray : Array α := Array.mkEmpty branching.toNat
/- `i * 2^shift`, computed as a left shift. -/
abbrev mul2Shift (i : USize) (shift : USize) : USize := USize.shift_left i shift
/- `i / 2^shift`, computed as a right shift. -/
abbrev div2Shift (i : USize) (shift : USize) : USize := USize.shift_right i shift
/- `i % 2^shift`, computed by masking with `2^shift - 1`. -/
abbrev mod2Shift (i : USize) (shift : USize) : USize := USize.land i ((USize.shift_left 1 shift) - 1)
/- Look up index `i` below a node. At an interior node the bits of `i` above `shift`
   pick the child and the remaining bits are passed down with `shift` reduced by
   `initShift`; at a leaf `i` indexes the element array directly. -/
partial def getAux [Inhabited α] : PersistentArrayNode α → USize → USize → α
| (node cs) i shift :=
  let j     := div2Shift i shift in
  let i     := mod2Shift i shift in
  let shift := shift - initShift in
  getAux (cs.get j.toNat) i shift
| (leaf vs) i _ := vs.get i.toNat
/- Return the element at index `i`. Indices below `tailOff` are looked up in the
   `root` tree; all other indices are read from `tail` at offset `i - tailOff`. -/
def get [Inhabited α] (t : PersistentArray α) (i : Nat) : α :=
if i < t.tailOff then
  getAux t.root (USize.ofNat i) t.shift
else
  t.tail.get (i - t.tailOff)
/- Replace the element at index `i` below a node with `a`, rebuilding the nodes on
   the path. The index is split exactly as in `getAux`. -/
partial def setAux : PersistentArrayNode α → USize → USize → α → PersistentArrayNode α
| (node cs) i shift a :=
  node (cs.modify (div2Shift i shift).toNat $ λ c,
    setAux c (mod2Shift i shift) (shift - initShift) a)
| (leaf vs) i _ a := leaf (vs.set i.toNat a)
/- Store `a` at index `i`. Indices below `tailOff` are updated inside the `root`
   tree; all other indices are updated in `tail` at offset `i - tailOff`. -/
def set (t : PersistentArray α) (i : Nat) (a : α) : PersistentArray α :=
if i < t.tailOff then
  { root := setAux t.root (USize.ofNat i) t.shift a, .. t }
else
  { tail := t.tail.set (i - t.tailOff) a, .. t }
/- Build a chain of single-child nodes that ends in `leaf vs`. One node is created
   per `initShift` step until `shift` reaches `0`. -/
partial def mkNewPath : USize → Array α → PersistentArrayNode α
| shift vs :=
  if shift == 0 then
    leaf vs
  else
    let child := mkNewPath (shift - initShift) vs in
    node (mkEmptyArray.push child)
/- Add the element array `a` as a new leaf below the given node.
`i` is an index belonging to the new leaf (`mkNewTail` passes `size - 1`) and
`shift` is the shift of the current node. -/
partial def insertNewLeaf : PersistentArrayNode α → USize → USize → Array α → PersistentArrayNode α
| (node cs) i shift a :=
if i < branching then
-- `i` is below `branching`: append the leaf directly to this node
node (cs.push (leaf a))
else
let j := div2Shift i shift in
let i := mod2Shift i shift in
let shift := shift - initShift in
if j.toNat < cs.size then
-- child `j` already exists: descend into it with the reduced index and shift
node $ cs.modify j.toNat $ λ c, insertNewLeaf c i shift a
else
-- no child `j` yet: append a fresh path that ends in the new leaf
node $ cs.push $ mkNewPath shift a
| n _ _ _ := n -- unreachable
/- Move the current tail of `t` into the tree and start a new, empty tail whose
   offset is `t.size`.
   If `t.size` does not exceed `2^(shift + initShift)`, the old tail becomes a new
   leaf under the existing root. Otherwise a new root is created with the old root
   and a fresh path to the old tail as its two children, and `shift` grows by
   `initShift`.
   Both branches create the new tail with `mkEmptyArray`, so it always starts with
   the same reserved capacity as the default tail. -/
def mkNewTail (t : PersistentArray α) : PersistentArray α :=
if t.size <= (mul2Shift 1 (t.shift + initShift)).toNat then
  { tail    := mkEmptyArray,
    root    := insertNewLeaf t.root (USize.ofNat (t.size - 1)) t.shift t.tail,
    tailOff := t.size,
    .. t }
else
  { tail    := mkEmptyArray,
    root    := let n := mkEmptyArray.push t.root in
               node (n.push (mkNewPath t.shift t.tail)),
    shift   := t.shift + initShift,
    tailOff := t.size,
    .. t }
/- Size threshold (`usizeSz / 8`) at which `push` stops moving the tail into the tree. -/
def tooBig : Nat := usizeSz / 8
/- Append `a`. The element always goes into the tail first. The tail is then moved
   into the tree by `mkNewTail` unless it still has fewer than `branching` elements
   or the array (before this push) already held `tooBig` or more elements. -/
def push (t : PersistentArray α) (a : α) : PersistentArray α :=
let grown := { tail := t.tail.push a, size := t.size + 1, .. t } in
if grown.tail.size < branching.toNat || t.size >= tooBig then
  grown
else
  mkNewTail grown
section
variables {m : Type v → Type v} [Monad m]
local attribute [instance] monadInhabited'
/- Monadic left fold of `f` below a node: an interior node folds over its children
by recursing into each one, a leaf folds `f` over its element array. -/
@[specialize] partial def mfoldlAux (f : β → α → m β) : PersistentArrayNode α → β → m β
| (node cs) b := cs.mfoldl (λ b c, mfoldlAux c b) b
| (leaf vs) b := vs.mfoldl f b
/- Monadic left fold of `f` over the whole array: first the `root` tree, then the `tail`. -/
@[specialize] def mfoldl (f : β → α → m β) (b : β) (t : PersistentArray α) : m β :=
do b ← mfoldlAux f t.root b, t.tail.mfoldl f b
end
/- Pure left fold: runs `mfoldl` in the `Id` monad. -/
@[inline] def foldl (f : β → α → β) (b : β) (t : PersistentArray α) : β :=
Id.run (t.mfoldl f b)
/- Convert to a list: elements are consed onto an accumulator while folding, and the
   accumulated list is reversed at the end. -/
def toList (t : PersistentArray α) : List α :=
List.reverse (t.foldl (λ acc a, a :: acc) [])
section
variables {m : Type v → Type v} [Monad m]
/- Monadic map below a node: rebuilds the same node/leaf shape with `f` applied to
every element of every leaf. -/
@[specialize] partial def mmapAux (f : α → m β) : PersistentArrayNode α → m (PersistentArrayNode β)
| (node cs) := node <$> cs.mmap (λ c, mmapAux c)
| (leaf vs) := leaf <$> vs.mmap f
/- Monadic map over the whole array: maps the `root` tree, then the `tail`; the
remaining fields (`size`, `shift`, `tailOff`) are copied from `t`. -/
@[specialize] def mmap (f : α → m β) (t : PersistentArray α) : m (PersistentArray β) :=
do
root ← mmapAux f t.root,
tail ← t.tail.mmap f,
pure { tail := tail, root := root, .. t }
end
/- Pure map: runs `mmap` in the `Id` monad. -/
@[inline] def map (f : α → β) (t : PersistentArray α) : PersistentArray β :=
Id.run (t.mmap f)
/- Shape statistics reported by `stats`: number of tree nodes (interior and leaf),
maximum depth reached, and number of elements in the tail. -/
structure Stats :=
(numNodes : Nat) (depth : Nat) (tailSize : Nat)
/- Accumulate statistics for the subtree at the given node, which sits at depth `d`.
Every node and every leaf adds one to `numNodes` and raises `depth` to at least `d`;
children of an interior node are visited at depth `d+1`. -/
partial def collectStats : PersistentArrayNode α → Stats → Nat → Stats
| (node cs) s d :=
cs.foldl (λ s c, collectStats c s (d+1))
{ numNodes := s.numNodes + 1,
depth := Nat.max d s.depth, .. s }
| (leaf vs) s d := { numNodes := s.numNodes + 1, depth := Nat.max d s.depth, .. s }
/- Statistics for `r`: the tree is walked from `root` at depth 0; `tailSize` is the
current size of the tail. -/
def stats (r : PersistentArray α) : Stats :=
collectStats r.root { numNodes := 0, depth := 0, tailSize := r.tail.size } 0
/- Render the statistics as the list `[numNodes, depth, tailSize]`. -/
def Stats.toString (s : Stats) : String :=
toString [s.numNodes, s.depth, s.tailSize]
instance : HasToString Stats := ⟨Stats.toString⟩
end PersistentArray
/- Push every element of the list, front to back, onto the given persistent array. -/
def List.toPersistentArrayAux {α : Type u} : List α → PersistentArray α → PersistentArray α
| xs t := xs.foldl (λ acc x, acc.push x) t
/- Build a persistent array from a list by pushing its elements onto the empty array `{}`. -/
def List.toPersistentArray {α : Type u} (xs : List α) : PersistentArray α :=
xs.toPersistentArrayAux {}

View file

@ -0,0 +1,7 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import init.data.persistentarray.basic

View file

@ -194,10 +194,6 @@ void object_compactor::operator()(object * o) {
case object_kind::Task: r = insert_task(curr); break;
case object_kind::Ref: r = insert_ref(curr); break;
case object_kind::External: throw exception("external objects cannot be compacted");
case object_kind::PArrayRoot:
case object_kind::PArraySet:
case object_kind::PArrayPush:
case object_kind::PArrayPop: throw exception("persistent array objects cannot be compacted");
}
if (r) m_todo.pop_back();
}
@ -315,10 +311,6 @@ object * compacted_region::read() {
case object_kind::Ref: fix_ref(curr); break;
case object_kind::Task: lean_unreachable();
case object_kind::External: lean_unreachable();
case object_kind::PArrayRoot: lean_unreachable();
case object_kind::PArraySet: lean_unreachable();
case object_kind::PArrayPush: lean_unreachable();
case object_kind::PArrayPop: lean_unreachable();
}
}
}

View file

@ -57,18 +57,6 @@ obj_res set_io_error(obj_arg r, std::string const & msg) {
return set_io_error(r, mk_io_user_error(mk_string(msg)));
}
/* Convert an IO result into an option-like object: a one-field constructor (tag 1)
   holding the result value when `r` is ok, `box(0)` otherwise. `r` is released on
   both paths. */
static obj_res option_of_io_result(obj_arg r) {
    if (!io_result_is_ok(r)) {
        dec(r);
        return box(0);
    }
    object * some = alloc_cnstr(1, 1, 0);
    cnstr_set(some, 0, io_result_get_value(r));
    dec(r);
    return some;
}
static bool g_initializing = true;
void io_mark_end_initialization() {
g_initializing = false;

View file

@ -56,10 +56,6 @@ size_t obj_byte_size(object * o) {
case object_kind::MPZ: return sizeof(mpz_object);
case object_kind::Thunk: return sizeof(thunk_object);
case object_kind::Task: return sizeof(task_object);
case object_kind::PArrayRoot: return sizeof(parray_object);
case object_kind::PArraySet: return sizeof(parray_object);
case object_kind::PArrayPush: return sizeof(parray_object);
case object_kind::PArrayPop: return sizeof(parray_object);
case object_kind::Ref: return sizeof(ref_object);
case object_kind::External: lean_unreachable();
}
@ -76,10 +72,6 @@ size_t obj_header_size(object * o) {
case object_kind::MPZ: return sizeof(mpz_object);
case object_kind::Thunk: return sizeof(thunk_object);
case object_kind::Task: return sizeof(task_object);
case object_kind::PArrayRoot: return sizeof(parray_object);
case object_kind::PArraySet: return sizeof(parray_object);
case object_kind::PArrayPush: return sizeof(parray_object);
case object_kind::PArrayPop: return sizeof(parray_object);
case object_kind::Ref: return sizeof(ref_object);
case object_kind::External: lean_unreachable();
}
@ -126,23 +118,6 @@ inline void dec(object * o, object* & todo) {
void deactivate_task(task_object * t);
/* Read the capacity stored in the `size_t` word that immediately precedes `data`
   (written by `alloc_parray_data`). */
static size_t parray_data_capacity(object ** data) {
    size_t * header = reinterpret_cast<size_t*>(data) - 1;
    return *header;
}
/* Allocate room for `c` object pointers preceded by one `size_t` header word that
   records `c`; the returned pointer addresses the first pointer slot. */
static object ** alloc_parray_data(size_t c) {
    size_t * header = static_cast<size_t*>(malloc(sizeof(object*)*c + sizeof(size_t)));
    header[0] = c;
    return reinterpret_cast<object**>(header + 1);
}
/* Free a buffer obtained from `alloc_parray_data`: step back over the capacity
   header word to recover the pointer that `malloc` returned. */
static void dealloc_parray_data(object ** data) {
    free(reinterpret_cast<size_t*>(data) - 1);
}
#ifdef LEAN_SMALL_ALLOCATOR
static inline void free_heap_obj_core(object * o, size_t sz) {
#else
@ -214,10 +189,6 @@ static inline void free_task_obj(object * o) {
FREE_OBJ(o, sizeof(task_object) + sizeof(rc_type));
}
/* Release the memory of a parray cell; the size passed to FREE_OBJ is the cell plus one `rc_type`. */
static inline void free_parray_obj(object * o) {
FREE_OBJ(o, sizeof(parray_object) + sizeof(rc_type));
}
static inline void free_external_obj(object * o) {
FREE_OBJ(o, sizeof(external_object) + sizeof(rc_type));
}
@ -265,24 +236,6 @@ static void del_core(object * o, object * & todo) {
case object_kind::Task:
deactivate_task(to_task(o));
break;
case object_kind::PArrayPop:
dec_ref(to_parray(o)->m_next, todo);
free_parray_obj(o);
break;
case object_kind::PArrayPush:
case object_kind::PArraySet:
dec(to_parray(o)->m_elem, todo);
dec_ref(to_parray(o)->m_next, todo);
free_parray_obj(o);
break;
case object_kind::PArrayRoot: {
object ** it = to_parray(o)->m_data;
object ** end = it + to_parray(o)->m_size;
for (; it != end; ++it) dec(*it, todo);
dealloc_parray_data(to_parray(o)->m_data);
free_parray_obj(o);
break;
}
case object_kind::External:
to_external(o)->m_class->m_finalize(to_external(o)->m_data);
free_external_obj(o);
@ -330,184 +283,6 @@ object * array_mk_empty() {
return g_array_empty;
}
// =======================================
// Persistent arrays
/* Move the first `sz` pointers of `data` into a larger buffer and free the old one.
   The new capacity is 2 when the old capacity is 0, otherwise (3*old + 1)/2. */
static object ** parray_data_expand(object ** data, size_t sz) {
    size_t old_cap = parray_data_capacity(data);
    size_t new_cap = 2;
    if (old_cap != 0)
        new_cap = (3 * old_cap + 1) >> 1;
    object ** grown = alloc_parray_data(new_cap);
    memcpy(grown, data, sizeof(object*)*sz);
    dealloc_parray_data(data);
    return grown;
}
/* Given `c -> ... -> root`,
reverse the links and make `c` the new root:
`c <- ... <- root`
Precondition: `c` is not already a PArrayRoot cell. */
static void parray_reroot(object * c) {
lean_assert(get_kind(c) != object_kind::PArrayRoot);
parray_object * it = to_parray(c);
parray_object * prev = nullptr;
/* invert links */
while (get_kind(it) != object_kind::PArrayRoot) {
/* c <- ... <- prev, it -> it_next -> ... -> root
c <- ... <- prev <- it, it_next -> ... -> root */
parray_object * it_next = it->m_next;
it->m_next = prev;
prev = it;
it = it_next;
}
lean_assert(prev != nullptr);
lean_assert(get_kind(it) == object_kind::PArrayRoot);
lean_assert(it != c);
object * old_root = it;
it->m_next = prev;
/* Save the buffer and size now: the cells' edit fields are overwritten below. */
object ** data = it->m_data;
size_t sz = it->m_size;
prev = it;
it = it->m_next;
/* move array to `c` */
/* Walk from the old root towards `c`. Each cell's edit is applied to `data`/`sz`,
and the inverse edit is recorded in the cell behind it (`prev`). */
while (true) {
lean_assert(prev != nullptr && it != prev);
switch (get_kind(it)) {
case object_kind::PArraySet:
/* apply the set; `prev` remembers the value that was overwritten */
prev->m_kind = static_cast<unsigned>(object_kind::PArraySet);
prev->m_idx = it->m_idx;
prev->m_elem = data[it->m_idx];
data[it->m_idx] = it->m_elem;
break;
case object_kind::PArrayPush:
/* apply the push (growing the buffer if full); the inverse is a pop */
if (sz == parray_data_capacity(data))
data = parray_data_expand(data, sz);
prev->m_kind = static_cast<unsigned>(object_kind::PArrayPop);
data[sz] = it->m_elem;
sz++;
break;
case object_kind::PArrayPop:
/* apply the pop; the inverse is a push of the removed element */
--sz;
prev->m_kind = static_cast<unsigned>(object_kind::PArrayPush);
prev->m_elem = data[sz];
break;
default:
lean_unreachable();
}
if (it == c)
break;
prev = it;
it = it->m_next;
}
lean_assert(it == c);
/* `c` becomes the root cell and now holds the buffer */
it->m_kind = static_cast<unsigned>(object_kind::PArrayRoot);
it->m_data = data;
it->m_size = sz;
dec_ref(old_root);
inc_ref(c);
}
/* Create a fresh root cell that takes over the buffer and size of the shared root `src`.
Precondition: `src` is a root with RC > 1; one reference to `src` is dropped here.
The caller is expected to turn `src` into an edit cell pointing at the result. */
static parray_object * move_parray_root(parray_object * src) {
lean_assert(src->m_kind == static_cast<unsigned>(object_kind::PArrayRoot));
lean_assert(get_rc(src) > 1);
dec_ref(src);
parray_object * r = new (alloc_heap_object(sizeof(parray_object))) parray_object();
r->m_data = src->m_data;
r->m_size = src->m_size;
return r;
}
/* Allocate an empty persistent array (a root cell of size 0) whose buffer has the given capacity. */
obj_res alloc_parray(size_t capacity) {
parray_object * r = new (alloc_heap_object(sizeof(parray_object))) parray_object();
r->m_data = alloc_parray_data(capacity);
r->m_size = 0;
return r;
}
/* Number of elements in the array version `o`. `o` is rerooted first when it is
   not the root cell, so this call may mutate the cell chain. */
size_t parray_size(b_obj_arg o) {
    bool is_root = get_kind(o) == object_kind::PArrayRoot;
    if (!is_root)
        parray_reroot(o);
    parray_object * p = to_parray(o);
    return p->m_size;
}
/* Element `i` of the array version `o` (no bounds check). `o` is rerooted first when
   it is not the root cell, so this call may mutate the cell chain. */
b_obj_res parray_get(b_obj_arg o, size_t i) {
    bool is_root = get_kind(o) == object_kind::PArrayRoot;
    if (!is_root)
        parray_reroot(o);
    parray_object * p = to_parray(o);
    return p->m_data[i];
}
/* Return a version of `o` in which element `i` is `v` (no bounds check). */
obj_res parray_set(obj_arg o, size_t i, obj_arg v) {
if (get_kind(o) != object_kind::PArrayRoot)
parray_reroot(o);
parray_object * p = to_parray(o);
if (get_rc(p) == 1) {
/* unique reference: overwrite in place, releasing the old element */
dec(p->m_data[i]);
p->m_data[i] = v;
return p;
} else {
/* shared: a new root `r` takes the buffer; the old cell `p` becomes a PArraySet
cell that remembers the previous value at `i` and links to `r` */
parray_object * r = move_parray_root(p);
p->m_kind = static_cast<unsigned>(object_kind::PArraySet);
p->m_idx = i;
p->m_elem = r->m_data[i];
p->m_next = r;
inc_ref(r);
r->m_data[i] = v;
return r;
}
}
/* Return a version of `o` with `v` appended. */
obj_res parray_push(obj_arg o, obj_arg v) {
if (get_kind(o) != object_kind::PArrayRoot)
parray_reroot(o);
parray_object * p = to_parray(o);
/* grow the buffer first so both branches below can write at index `m_size` */
if (p->m_size == parray_data_capacity(p->m_data))
p->m_data = parray_data_expand(p->m_data, p->m_size);
if (get_rc(p) == 1) {
/* unique reference: append in place */
p->m_data[p->m_size] = v;
p->m_size++;
return p;
} else {
/* shared: a new root `r` takes the buffer; the old cell `p` becomes a PArrayPop
cell (the inverse of this push) that links to `r` */
parray_object * r = move_parray_root(p);
p->m_kind = static_cast<unsigned>(object_kind::PArrayPop);
p->m_next = r;
inc_ref(r);
r->m_data[r->m_size] = v;
r->m_size++;
return r;
}
}
/* Return a version of `o` without its last element.
NOTE(review): there is no emptiness check here — presumably callers guarantee size > 0. */
obj_res parray_pop(obj_arg o) {
if (get_kind(o) != object_kind::PArrayRoot)
parray_reroot(o);
parray_object * p = to_parray(o);
if (get_rc(p) == 1) {
/* unique reference: shrink in place and release the removed element */
p->m_size--;
dec(p->m_data[p->m_size]);
return p;
} else {
/* shared: a new root `r` takes the buffer; the old cell `p` becomes a PArrayPush
cell that keeps the removed element and links to `r` */
parray_object * r = move_parray_root(p);
r->m_size--;
p->m_kind = static_cast<unsigned>(object_kind::PArrayPush);
p->m_elem = r->m_data[r->m_size];
p->m_next = r;
inc_ref(r);
return r;
}
}
/* Return a fresh, independent root cell holding a copy of the array version `o`.
   `o` is rerooted first when it is not the root cell. The new buffer has the same
   capacity as the source buffer, and every copied element gets one extra reference. */
obj_res parray_copy(b_obj_arg o) {
    if (get_kind(o) != object_kind::PArrayRoot)
        parray_reroot(o);
    size_t sz = to_parray(o)->m_size;
    object ** data = to_parray(o)->m_data;
    parray_object * r = new (alloc_heap_object(sizeof(parray_object))) parray_object();
    r->m_size = sz;
    r->m_data = alloc_parray_data(parray_data_capacity(data));
    /* memcpy takes a byte count: copy `sz` pointers, not `sz` bytes */
    memcpy(r->m_data, data, sizeof(object*)*sz);
    for (size_t i = 0; i < sz; i++)
        inc(data[i]);
    return r;
}
// =======================================
// Closures
@ -878,13 +653,6 @@ void mark_mt(object * o) {
case object_kind::String:
case object_kind::MPZ:
return;
case object_kind::PArrayPop:
case object_kind::PArrayPush:
case object_kind::PArraySet:
case object_kind::PArrayRoot:
/* `mark_mt` cannot be used with parray. They must be copied when used in multiple threads. */
lean_unreachable();
return;
case object_kind::External: {
object * fn = alloc_closure(reinterpret_cast<void*>(mark_mt_fn), 1, 0);
to_external(o)->m_class->m_foreach(to_external(o)->m_data, fn);
@ -1082,20 +850,6 @@ void mark_persistent(object * o) {
case object_kind::String:
case object_kind::MPZ:
return;
case object_kind::PArrayPop:
mark_persistent(to_parray(o)->m_next);
return;
case object_kind::PArrayPush:
case object_kind::PArraySet:
mark_persistent(to_parray(o)->m_elem);
mark_persistent(to_parray(o)->m_next);
return;
case object_kind::PArrayRoot: {
object ** it = to_parray(o)->m_data;
object ** end = it + to_parray(o)->m_size;
for (; it != end; ++it) mark_persistent(*it);
return;
}
case object_kind::External: {
object * fn = alloc_closure(reinterpret_cast<void*>(mark_persistent_fn), 1, 0);
to_external(o)->m_class->m_foreach(to_external(o)->m_data, fn);

View file

@ -64,7 +64,6 @@ extern atomic<uint64> g_num_del;
enum class object_memory_kind { MTHeap = 0, STHeap, Persistent, Stack, Region };
enum class object_kind { Constructor, Closure, Array, ScalarArray,
PArrayRoot, PArraySet, PArrayPush, PArrayPop,
String, MPZ, Thunk, Task, Ref, External };
/* Objects are initially allocated as STHeap. When we create a task, we change it to MTHeap. */
@ -162,23 +161,6 @@ struct string_object : public object {
object(object_kind::String, m), m_size(sz), m_capacity(c), m_length(len) {}
};
/* Persistent arrays are implemented using 4 different kinds of cell:
PArraySet, PArrayPush, PArrayPop and PArrayRoot.
A PArrayRoot cell holds the element buffer (`m_data`) and its size (`m_size`).
The other kinds record a single edit and link to the next cell through `m_next`.
The two unions overlay the root fields with the edit fields, so a cell's kind
decides which members are meaningful. */
struct parray_object : public object {
parray_object * m_next; // PArraySet, PArrayPush, PArrayPop
union {
size_t m_idx; // PArraySet
size_t m_size; // PArrayRoot
};
union {
object ** m_data; // PArrayRoot
object * m_elem; // PArrayPush and PArraySet
};
/* Remark: persistent arrays are single threaded object. The `mark_shared` operation
copies it when the RC > 1 */
/* New cells start out as PArrayRoot in the single-threaded heap. */
parray_object():object(object_kind::PArrayRoot, object_memory_kind::STHeap) {}
};
/* Note that `m_fun` is a pointer to a C function.
The `apply` operator performs a cast operation. It casts m_fun to a C function pointer of the right arity.
@ -383,7 +365,6 @@ inline bool is_cnstr(object * o) { return get_kind(o) == object_kind::Constructo
inline bool is_closure(object * o) { return get_kind(o) == object_kind::Closure; }
inline bool is_array(object * o) { return get_kind(o) == object_kind::Array; }
inline bool is_sarray(object * o) { return get_kind(o) == object_kind::ScalarArray; }
/* True when `o` is any of the four persistent-array cell kinds. */
inline bool is_parray(object * o) {
    switch (get_kind(o)) {
    case object_kind::PArrayRoot:
    case object_kind::PArraySet:
    case object_kind::PArrayPush:
    case object_kind::PArrayPop:
        return true;
    default:
        return false;
    }
}
inline bool is_string(object * o) { return get_kind(o) == object_kind::String; }
inline bool is_mpz(object * o) { return get_kind(o) == object_kind::MPZ; }
inline bool is_thunk(object * o) { return get_kind(o) == object_kind::Thunk; }
@ -398,7 +379,6 @@ inline constructor_object * to_cnstr(object * o) { lean_assert(is_cnstr(o)); ret
inline closure_object * to_closure(object * o) { lean_assert(is_closure(o)); return static_cast<closure_object*>(o); }
inline array_object * to_array(object * o) { lean_assert(is_array(o)); return static_cast<array_object*>(o); }
inline sarray_object * to_sarray(object * o) { lean_assert(is_sarray(o)); return static_cast<sarray_object*>(o); }
/* Downcast to a parray cell; asserts (via lean_assert) that `o` is one. */
inline parray_object * to_parray(object * o) { lean_assert(is_parray(o)); return static_cast<parray_object*>(o); }
inline string_object * to_string(object * o) { lean_assert(is_string(o)); return static_cast<string_object*>(o); }
inline mpz_object * to_mpz(object * o) { lean_assert(is_mpz(o)); return static_cast<mpz_object*>(o); }
inline thunk_object * to_thunk(object * o) { lean_assert(is_thunk(o)); return static_cast<thunk_object*>(o); }
@ -700,19 +680,6 @@ inline void array_set(u_obj_arg o, size_t i, obj_arg v) {
obj_set_data(o, sizeof(array_object) + sizeof(object*)*i, v); // NOLINT
}
// =======================================
// Persistent Array of objects
obj_res alloc_parray(size_t capacity);
size_t parray_size(b_obj_arg o);
b_obj_res parray_get(b_obj_arg o, size_t i);
obj_res parray_set(obj_arg o, size_t i, obj_arg v);
obj_res parray_push(obj_arg o, obj_arg v);
obj_res parray_pop(obj_arg o);
obj_res parray_copy(b_obj_arg o);
// =======================================
// =======================================
// Array of scalars

View file

@ -389,45 +389,6 @@ void tst13() {
std::cout << g_task7_counter << "\n";
}
/* Test helper: build a persistent array of `n` copies of `v` (initial capacity `n`).
`v` is borrowed; one reference is added per pushed copy. */
obj_res mk_parray(unsigned n, b_obj_arg v) {
object * r = alloc_parray(n);
for (unsigned i = 0; i < n; i++) {
inc(v);
r = parray_push(r, v);
}
return r;
}
/* Checks that `parray_set` / `parray_push` on a shared array produce a new version
while older versions keep their contents, and checks the reference counts along the way. */
void tst14() {
object * a = mk_parray(10, box(0));
lean_assert(parray_size(a) == 10);
lean_assert(parray_get(a, 0) == box(0));
/* share the array: `a` and `b` are the same object with RC 2 */
object * b = a;
inc(b);
lean_assert(get_rc(a) == 2);
lean_assert(get_rc(b) == 2);
/* updating a shared array must yield a distinct object */
a = parray_set(a, 0, box(1));
lean_assert(a != b);
lean_assert(get_rc(b) == 1);
lean_assert(get_rc(a) == 2);
lean_assert(parray_get(a, 0) == box(1));
lean_assert(parray_get(a, 1) == box(0));
/* reading the old version `b`, then `a` again, must still give each version's own value */
lean_assert(parray_get(b, 0) == box(0));
lean_assert(parray_get(a, 0) == box(1));
inc(b);
object * c = b;
c = parray_push(c, box(20));
lean_assert(parray_size(c) == 11);
lean_assert(parray_size(a) == 10);
lean_assert(parray_size(b) == 10);
lean_assert(parray_get(c, 0) == box(0));
lean_assert(parray_get(c, 10) == box(20));
lean_assert(parray_get(a, 0) == box(1));
lean_assert(parray_get(b, 0) == box(0));
dec_ref(a); dec_ref(b);
lean_assert(get_rc(c) == 1);
dec_ref(c);
}
obj_res mk_foo(unsigned n) {
object * r = alloc_cnstr(0, 1, 0);
@ -439,106 +400,6 @@ unsigned foo_val(b_obj_arg v) {
return unbox(cnstr_get(v, 0));
}
/* Checks persistent arrays holding heap objects (`mk_foo`): an old version `v2` keeps
its two elements after later pushes and sets on `v1`, and a version `v3` taken before
two pops keeps its size. */
void tst15() {
object * v1 = alloc_parray(0);
v1 = parray_push(v1, mk_foo(2));
v1 = parray_push(v1, mk_foo(3));
lean_assert(foo_val(parray_get(v1, 0)) == 2);
lean_assert(foo_val(parray_get(v1, 1)) == 3);
/* snapshot with 2 elements */
object * v2 = v1;
inc(v2);
for (unsigned i = 0; i < 10; i++)
v1 = parray_push(v1, mk_foo(i));
v1 = parray_set(v1, 0, mk_foo(100));
v1 = parray_set(v1, 1, mk_foo(100));
lean_assert(parray_size(v2) == 2);
lean_assert(foo_val(parray_get(v2, 0)) == 2);
lean_assert(foo_val(parray_get(v2, 1)) == 3);
/* snapshot with 12 elements */
object * v3 = v1;
inc(v3);
v1 = parray_pop(v1);
v1 = parray_pop(v1);
lean_assert(parray_size(v1) == 10);
lean_assert(parray_size(v3) == 12);
dec_ref(v1); dec_ref(v2);
lean_assert(get_rc(v3) == 1);
dec_ref(v3);
}
/* Randomized test: applies `num_it` rounds of random push/pop/set operations to a
persistent array `v1` while mirroring them on a std::vector `v2`, and checks after every
round that both agree. With probability `copy_freq` a snapshot (shared reference plus a
copy of `v2`) is recorded per round; all snapshots are re-checked at the end.
`max_sz` bounds the array size, values are drawn from [0, max_val), and the `*_freq`
arguments are per-round probabilities. The RNG is seeded from the current time, so
runs are not reproducible. */
void driver(unsigned max_sz, unsigned max_val, unsigned num_it,
double push_freq,
double pop_freq,
double set_freq,
double copy_freq) {
object * v1 = alloc_parray(0);
std::vector<unsigned> v2;
std::mt19937 rng;
rng.seed(static_cast<unsigned int>(time(0)));
std::uniform_int_distribution<unsigned int> uint_dist;
std::vector<std::pair<object*, std::vector<unsigned>>> copies;
lean_assert(get_rc(v1) == 1);
size_t acc_sz = 0;
for (unsigned i = 0; i < num_it; i++) {
acc_sz += parray_size(v1);
double f = static_cast<double>(uint_dist(rng) % 10000) / 10000.0;
if (f < copy_freq) {
/* snapshot: share `v1` and remember the expected contents */
inc_ref(v1);
copies.emplace_back(v1, v2);
}
f = static_cast<double>(uint_dist(rng) % 10000) / 10000.0;
if (f < push_freq) {
if (parray_size(v1) < max_sz) {
unsigned a = uint_dist(rng) % max_val;
v1 = parray_push(v1, box(a));
v2.push_back(a);
}
}
if (parray_size(v1) > 0) {
f = static_cast<double>(uint_dist(rng) % 10000) / 10000.0;
if (f < pop_freq) {
v1 = parray_pop(v1);
v2.pop_back();
}
}
if (parray_size(v1) > 0) {
f = static_cast<double>(uint_dist(rng) % 10000) / 10000.0;
if (f < set_freq) {
unsigned idx = uint_dist(rng) % parray_size(v1);
unsigned a = uint_dist(rng) % max_val;
v1 = parray_set(v1, idx, box(a));
v2[idx] = a;
}
}
/* NOTE(review): the value drawn on the next line is never read afterwards */
f = static_cast<double>(uint_dist(rng) % 10000) / 10000.0;
lean_assert(parray_size(v1) == v2.size());
for (unsigned i = 0; i < v2.size(); i++) {
lean_assert(unbox(parray_get(v1, i)) == v2[i]);
}
}
/* every snapshot must still match the contents recorded when it was taken */
for (std::pair<object *, std::vector<unsigned>> const & p : copies) {
lean_assert(parray_size(p.first) == p.second.size());
for (unsigned i = 0; i < p.second.size(); i++) {
lean_assert(unbox(parray_get(p.first, i)) == p.second[i]);
}
dec_ref(p.first);
}
std::cout << "\n";
std::cout << "Copies created: " << copies.size() << "\n";
std::cout << "Average size: " << static_cast<double>(acc_sz) / static_cast<double>(num_it) << "\n";
lean_assert(get_rc(v1) == 1);
dec_ref(v1);
}
/* Runs the randomized `driver` with several size limits, iteration counts and
operation frequencies (arguments: max_sz, max_val, num_it, push, pop, set, copy). */
void tst16() {
driver(4, 32, 10000, 0.5, 0.1, 0.5, 0.1);
driver(4, 32, 10000, 0.5, 0.1, 0.5, 0.1);
driver(4, 32, 10000, 0.5, 0.1, 0.5, 0.5);
driver(16, 16, 100000, 0.5, 0.5, 0.5, 0.01);
driver(16, 16, 100000, 0.5, 0.1, 0.5, 0.01);
driver(16, 16, 100000, 0.5, 0.6, 0.5, 0.01);
driver(16, 16, 10000, 0.5, 0.1, 0.5, 0.0);
}
object * mk_list(unsigned n) {
object * r = box(0);
for (unsigned i = 0; i < n; i++) {
@ -878,9 +739,6 @@ int main() {
tst11();
tst12();
tst13();
tst14();
tst15();
tst16();
// tst17(40000, 3000);
tst17(400, 30);
// tst18(4000, 3000);

View file

@ -0,0 +1,28 @@
import init.data.persistentarray
-- Array type under test; swap in the commented-out line below to run the same test with `Array`.
abbrev MyArray := PersistentArray Nat
-- abbrev MyArray := Array Nat
-- Build an array by pushing the indices supplied by `Nat.fold` onto an empty array.
-- Presumably yields `0, 1, ..., n-1` in order (`checkId` in `main` relies on that).
def mkMyArray (n : Nat) : MyArray :=
n.fold (λ i s, s.push i) { PersistentArray . }
-- n.fold (λ i s, s.push i) Array.empty
-- For every index `i` visited by `n.mfor`, throw an `IO.userError` naming `i` and the stored
-- value unless `p i (s.get i)` holds.
def check (n : Nat) (p : Nat → Nat → Bool) (s : MyArray) : IO Unit :=
n.mfor $ λ i, unless (p i (s.get i)) (throw (IO.userError ("failed at " ++ toString i ++ " " ++ toString (s.get i))))
-- Add 1 to the element at every index visited by `n.fold`.
def inc1 (n : Nat) (s : MyArray) : MyArray :=
n.fold (λ i s, s.set i (s.get i + 1)) s
-- Check that each checked index `i` holds the value `i`.
def checkId (n : Nat) (s : MyArray) : IO Unit :=
check n (==) s
-- Test driver: reads the element count from the first command-line argument, builds the
-- array, checks it, increments every element, checks again, then prints size and statistics.
def main (xs : List String) : IO Unit :=
do
let n := xs.head.toNat,
let t := mkMyArray n,
checkId n t,
let t := inc1 n t,
check n (λ i v, v == i + 1) t,
IO.println t.size,
IO.println t.stats,
pure ()

View file

@ -1,200 +0,0 @@
import init.lean.format
open Lean
universes u v w
/- Radix-tree node: an interior `node` holds an array of child nodes,
a `leaf` holds an array of elements. -/
inductive RadixNode (α : Type u)
| node (cs : Array RadixNode) : RadixNode
| leaf (vs : Array α) : RadixNode
/- The default node is a leaf built from `Array.empty`. -/
instance RadixNode.inhabited {α : Type u} : Inhabited (RadixNode α) :=
⟨RadixNode.leaf Array.empty⟩
/- Number of index bits handled per tree level; also the default root `shift`. -/
abbrev RadixTree.initShift : USize := 5
/- Branching factor: `2 ^ initShift`, i.e. 32. -/
abbrev RadixTree.branching : USize := USize.ofNat (2 ^ RadixTree.initShift.toNat)
structure RadixTree (α : Type u) :=
/- Recall that we run out of memory if we have more than `usizeSz/8` elements.
So, we can stop adding elements to `root` once `size >= usizeSz/8` (see `tooBig` in `push`), and
keep growing the `tail`. This modification allows us to use `USize` instead
of `Nat` when traversing `root`. -/
(root : RadixNode α := RadixNode.node (Array.mkEmpty RadixTree.branching.toNat))
(tail : Array α := Array.mkEmpty RadixTree.branching.toNat)
(size : Nat := 0)
(shift : USize := RadixTree.initShift)
(tailOff : Nat := 0)
namespace RadixTree
/- TODO:
- Use proofs for showing that array accesses are not out of bounds.
-/
variables {α : Type u} {β : Type v}
open RadixNode
instance : Inhabited (RadixTree α) := ⟨{}⟩
def mkEmptyArray : Array α := Array.mkEmpty branching.toNat
/- Shift-based arithmetic: multiply by, divide by, and reduce modulo `2^shift`. -/
abbrev mul2Shift (i : USize) (shift : USize) : USize := USize.shift_left i shift
abbrev div2Shift (i : USize) (shift : USize) : USize := USize.shift_right i shift
abbrev mod2Shift (i : USize) (shift : USize) : USize := USize.land i ((USize.shift_left 1 shift) - 1)
/- Look up index `i` below a node: the bits of `i` above `shift` pick the child, the
remaining bits are passed down with `shift` reduced by `initShift`. -/
partial def getAux [Inhabited α] : RadixNode α → USize → USize → α
| (node cs) i shift := getAux (cs.get (div2Shift i shift).toNat) (mod2Shift i shift) (shift - initShift)
| (leaf cs) i _ := cs.get i.toNat
/- Indices at or past `tailOff` are read from `tail`; smaller ones from the `root` tree. -/
def get [Inhabited α] (t : RadixTree α) (i : Nat) : α :=
if i >= t.tailOff then
t.tail.get (i - t.tailOff)
else
getAux t.root (USize.ofNat i) t.shift
/- Replace the element at index `i` below a node with `a`; the index is split as in `getAux`. -/
partial def setAux : RadixNode α → USize → USize → α → RadixNode α
| (node cs) i shift a := node (cs.modify (div2Shift i shift).toNat $ λ c,
setAux c (mod2Shift i shift) (shift - initShift) a)
| (leaf cs) i _ a := leaf (cs.set i.toNat a)
def set (t : RadixTree α) (i : Nat) (a : α) : RadixTree α :=
if i >= t.tailOff then
{ tail := t.tail.set (i - t.tailOff) a, .. t }
else
{ root := setAux t.root (USize.ofNat i) t.shift a, .. t }
/- Build a chain of single-child nodes ending in `leaf a`, one node per `initShift` step. -/
partial def mkNewPath : USize → Array α → RadixNode α
| shift a :=
if shift == 0 then
leaf a
else
node (mkEmptyArray.push (mkNewPath (shift - initShift) a))
/- Add the element array `a` as a new leaf below the given node; `i` is an index that
belongs to the new leaf. -/
partial def insertNewLeaf : RadixNode α → USize → USize → Array α → RadixNode α
| (node cs) i shift a :=
if i < branching then
node (cs.push (leaf a))
else
let j := div2Shift i shift in
if j.toNat < cs.size then
node (cs.modify j.toNat $ λ c, insertNewLeaf c (mod2Shift i shift) (shift - initShift) a)
else
node (cs.push (mkNewPath (shift - initShift) a))
| n _ _ _ := n -- unreachable
/- Move the current tail into the tree and start a new tail at offset `t.size`; a new
root level is added when `t.size` exceeds `2^(shift + initShift)`. -/
def mkNewTail (t : RadixTree α) : RadixTree α :=
if t.size <= (mul2Shift 1 (t.shift + initShift)).toNat then
{ tail := mkEmptyArray, root := insertNewLeaf t.root (USize.ofNat (t.size - 1)) t.shift t.tail,
tailOff := t.size,
.. t }
else
{ tail := Array.empty,
root := let n := mkEmptyArray.push t.root in
node (n.push (mkNewPath t.shift t.tail)),
shift := t.shift + initShift,
tailOff := t.size,
.. t }
def tooBig : Nat := usizeSz / 8
/- Append `a` to the tail; the tail is moved into the tree once it holds `branching`
elements, unless the array already held `tooBig` or more elements. -/
def push (t : RadixTree α) (a : α) : RadixTree α :=
let r := { tail := t.tail.push a, size := t.size + 1, .. t } in
if r.tail.size < branching.toNat || t.size >= tooBig then
r
else
mkNewTail r
section
variables {m : Type v → Type v} [Monad m]
local attribute [instance] monadInhabited'
@[specialize] partial def mfoldlAux (f : β → α → m β) : RadixNode α → β → m β
| (node cs) b := cs.mfoldl (λ b c, mfoldlAux c b) b
| (leaf vs) b := vs.mfoldl f b
/- Monadic left fold: first over the `root` tree, then over the `tail`. -/
@[specialize] def mfoldl (f : β → α → m β) (b : β) (t : RadixTree α) : m β :=
do b ← mfoldlAux f t.root b, t.tail.mfoldl f b
end
@[inline] def foldl (f : β → α → β) (b : β) (t : RadixTree α) : β :=
Id.run (t.mfoldl f b)
def toList (t : RadixTree α) : List α :=
(t.foldl (λ xs x, x :: xs) []).reverse
section
variables {m : Type v → Type v} [Monad m]
@[specialize] partial def mmapAux (f : α → m β) : RadixNode α → m (RadixNode β)
| (node cs) := node <$> cs.mmap (λ c, mmapAux c)
| (leaf vs) := leaf <$> vs.mmap f
/- Monadic map: maps the `root` tree, then the `tail`; the other fields are copied from `t`. -/
@[specialize] def mmap (f : α → m β) (t : RadixTree α) : m (RadixTree β) :=
do
root ← mmapAux f t.root,
tail ← t.tail.mmap f,
pure { tail := tail, root := root, .. t }
end
@[inline] def map (f : α → β) (t : RadixTree α) : RadixTree β :=
Id.run (t.mmap f)
structure Stats :=
(numNodes : Nat) (depth : Nat) (tailSize : Nat)
/- Every node and leaf adds one to `numNodes` and raises `depth` to at least its own
depth `d`; children are visited at depth `d+1`. -/
partial def collectStats : RadixNode α → Stats → Nat → Stats
| (node cs) s d :=
cs.foldl (λ s c, collectStats c s (d+1))
{ numNodes := s.numNodes + 1,
depth := Nat.max d s.depth, .. s }
| (leaf vs) s d := { numNodes := s.numNodes + 1, depth := Nat.max d s.depth, .. s }
def stats (r : RadixTree α) : Stats :=
collectStats r.root { numNodes := 0, depth := 0, tailSize := r.tail.size } 0
def Stats.toString (s : Stats) : String :=
toString [s.numNodes, s.depth, s.tailSize]
instance : HasToString Stats := ⟨Stats.toString⟩
/- Debug rendering of the tree structure: interior nodes print as `Node[...]`, leaves as lists. -/
partial def formatRawAux [HasFormat α] : RadixNode α → Format
| (node cs) := "Node" ++ Format.sbracket (cs.foldl (λ f c, f ++ Format.line ++ formatRawAux c) Format.nil)
| (leaf cs) := format cs.toList
partial def formatRaw [HasFormat α] (t : RadixTree α) : Format :=
Format.bracket "{" ("root :=" ++ Format.line ++ formatRawAux t.root ++ "," ++ Format.line ++
"tail :=" ++ Format.line ++ format t.tail.toList) "}"
end RadixTree
/- Push every element of the list, front to back, onto the given tree. -/
def List.toRadixTreeAux {α : Type u} : List α → RadixTree α → RadixTree α
| [] t := t
| (x::xs) t := List.toRadixTreeAux xs (t.push x)
/- Build a tree from a list by pushing its elements onto the empty tree `{}`. -/
def List.toRadixTree {α : Type u} (xs : List α) : RadixTree α :=
xs.toRadixTreeAux {}
-- Test script: array type under test; swap in the commented-out lines to run with `Array`.
abbrev PArray := RadixTree Nat
-- abbrev PArray := Array Nat
-- Build an array by pushing the indices supplied by `Nat.fold` onto an empty tree.
def mkRadixTree (n : Nat) : PArray :=
n.fold (λ i s, s.push i) { RadixTree . }
-- n.fold (λ i s, s.push i) Array.empty
-- Throw an `IO.userError` naming the index and stored value unless `p i (s.get i)` holds.
def check (n : Nat) (p : Nat → Nat → Bool) (s : PArray) : IO Unit :=
n.mfor $ λ i, unless (p i (s.get i)) (throw (IO.userError ("failed at " ++ toString i ++ " " ++ toString (s.get i))))
-- Add 1 to the element at every index visited by `n.fold`.
def inc1 (n : Nat) (s : PArray) : PArray :=
n.fold (λ i s, s.set i (s.get i + 1)) s
-- Check that each checked index `i` holds the value `i`.
def checkId (n : Nat) (s : PArray) : IO Unit :=
check n (==) s
-- Driver: element count comes from the first command-line argument; builds, checks,
-- increments, checks again, then prints size and statistics.
def main (xs : List String) : IO Unit :=
do
let n := xs.head.toNat,
let t := mkRadixTree n,
-- IO.println t.formatRaw *>
checkId n t,
let t := inc1 n t,
check n (λ i v, v == i + 1) t,
IO.println t.size,
IO.println t.stats,
pure ()