feat(library/init/data, runtime): remove parray support from runtime, and implement them in Lean using Scala/Clojure Radix trees
The Scala/Clojure approach for persistent arrays works great with our `reset/reuse`. We seem to be much more efficient than their implementations because of `reset/reuse`. The new approach also seems better than the old one implemented in the runtime, and has a few advantages: 1- The reroot procedure used in the old approach required synchronization for multi-threaded code, or we would need to perform deep copies when sending `parray` objects between threads. 2- We don't need any runtime extension for the new approach. 3- The old approach used "trail lists" for undoing array updates. This works well for bactracking search use cases, but it is bad in use cases where we are simultaneously updating the persistent arrays that have shared nodes.
This commit is contained in:
parent
9d00a8d262
commit
30a6a2ade8
9 changed files with 211 additions and 641 deletions
176
library/init/data/persistentarray/basic.lean
Normal file
176
library/init/data/persistentarray/basic.lean
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
/-
|
||||
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import init.data.array
|
||||
universes u v w
|
||||
|
||||
inductive PersistentArrayNode (α : Type u)
|
||||
| node (cs : Array PersistentArrayNode) : PersistentArrayNode
|
||||
| leaf (vs : Array α) : PersistentArrayNode
|
||||
|
||||
instance PersistentArrayNode.inhabited {α : Type u} : Inhabited (PersistentArrayNode α) :=
|
||||
⟨PersistentArrayNode.leaf Array.empty⟩
|
||||
|
||||
abbrev PersistentArray.initShift : USize := 5
|
||||
abbrev PersistentArray.branching : USize := USize.ofNat (2 ^ PersistentArray.initShift.toNat)
|
||||
|
||||
structure PersistentArray (α : Type u) :=
|
||||
/- Recall that we run out of memory if we have more than `usizeSz/8` elements.
|
||||
So, we can stop adding elements at `root` after `size > usizeSz`, and
|
||||
keep growing the `tail`. This modification allow us to use `USize` instead
|
||||
of `Nat` when traversing `root`. -/
|
||||
(root : PersistentArrayNode α := PersistentArrayNode.node (Array.mkEmpty PersistentArray.branching.toNat))
|
||||
(tail : Array α := Array.mkEmpty PersistentArray.branching.toNat)
|
||||
(size : Nat := 0)
|
||||
(shift : USize := PersistentArray.initShift)
|
||||
(tailOff : Nat := 0)
|
||||
|
||||
abbrev PArray (α : Type u) := PersistentArray α
|
||||
|
||||
namespace PersistentArray
|
||||
/- TODO: use proofs for showing that array accesses are not out of bounds.
|
||||
We can do it after we reimplement the tactic framework. -/
|
||||
variables {α : Type u} {β : Type v}
|
||||
open PersistentArrayNode
|
||||
|
||||
instance : Inhabited (PersistentArray α) := ⟨{}⟩
|
||||
|
||||
def mkEmptyArray : Array α := Array.mkEmpty branching.toNat
|
||||
|
||||
abbrev mul2Shift (i : USize) (shift : USize) : USize := USize.shift_left i shift
|
||||
abbrev div2Shift (i : USize) (shift : USize) : USize := USize.shift_right i shift
|
||||
abbrev mod2Shift (i : USize) (shift : USize) : USize := USize.land i ((USize.shift_left 1 shift) - 1)
|
||||
|
||||
partial def getAux [Inhabited α] : PersistentArrayNode α → USize → USize → α
|
||||
| (node cs) i shift := getAux (cs.get (div2Shift i shift).toNat) (mod2Shift i shift) (shift - initShift)
|
||||
| (leaf cs) i _ := cs.get i.toNat
|
||||
|
||||
def get [Inhabited α] (t : PersistentArray α) (i : Nat) : α :=
|
||||
if i >= t.tailOff then
|
||||
t.tail.get (i - t.tailOff)
|
||||
else
|
||||
getAux t.root (USize.ofNat i) t.shift
|
||||
|
||||
partial def setAux : PersistentArrayNode α → USize → USize → α → PersistentArrayNode α
|
||||
| (node cs) i shift a :=
|
||||
let j := div2Shift i shift in
|
||||
let i := mod2Shift i shift in
|
||||
let shift := shift - initShift in
|
||||
node $ cs.modify j.toNat $ λ c, setAux c i shift a
|
||||
| (leaf cs) i _ a := leaf (cs.set i.toNat a)
|
||||
|
||||
def set (t : PersistentArray α) (i : Nat) (a : α) : PersistentArray α :=
|
||||
if i >= t.tailOff then
|
||||
{ tail := t.tail.set (i - t.tailOff) a, .. t }
|
||||
else
|
||||
{ root := setAux t.root (USize.ofNat i) t.shift a, .. t }
|
||||
|
||||
partial def mkNewPath : USize → Array α → PersistentArrayNode α
|
||||
| shift a :=
|
||||
if shift == 0 then
|
||||
leaf a
|
||||
else
|
||||
node (mkEmptyArray.push (mkNewPath (shift - initShift) a))
|
||||
|
||||
partial def insertNewLeaf : PersistentArrayNode α → USize → USize → Array α → PersistentArrayNode α
|
||||
| (node cs) i shift a :=
|
||||
if i < branching then
|
||||
node (cs.push (leaf a))
|
||||
else
|
||||
let j := div2Shift i shift in
|
||||
let i := mod2Shift i shift in
|
||||
let shift := shift - initShift in
|
||||
if j.toNat < cs.size then
|
||||
node $ cs.modify j.toNat $ λ c, insertNewLeaf c i shift a
|
||||
else
|
||||
node $ cs.push $ mkNewPath shift a
|
||||
| n _ _ _ := n -- unreachable
|
||||
|
||||
def mkNewTail (t : PersistentArray α) : PersistentArray α :=
|
||||
if t.size <= (mul2Shift 1 (t.shift + initShift)).toNat then
|
||||
{ tail := mkEmptyArray, root := insertNewLeaf t.root (USize.ofNat (t.size - 1)) t.shift t.tail,
|
||||
tailOff := t.size,
|
||||
.. t }
|
||||
else
|
||||
{ tail := Array.empty,
|
||||
root := let n := mkEmptyArray.push t.root in
|
||||
node (n.push (mkNewPath t.shift t.tail)),
|
||||
shift := t.shift + initShift,
|
||||
tailOff := t.size,
|
||||
.. t }
|
||||
|
||||
def tooBig : Nat := usizeSz / 8
|
||||
|
||||
def push (t : PersistentArray α) (a : α) : PersistentArray α :=
|
||||
let r := { tail := t.tail.push a, size := t.size + 1, .. t } in
|
||||
if r.tail.size < branching.toNat || t.size >= tooBig then
|
||||
r
|
||||
else
|
||||
mkNewTail r
|
||||
|
||||
section
|
||||
variables {m : Type v → Type v} [Monad m]
|
||||
local attribute [instance] monadInhabited'
|
||||
|
||||
@[specialize] partial def mfoldlAux (f : β → α → m β) : PersistentArrayNode α → β → m β
|
||||
| (node cs) b := cs.mfoldl (λ b c, mfoldlAux c b) b
|
||||
| (leaf vs) b := vs.mfoldl f b
|
||||
|
||||
@[specialize] def mfoldl (f : β → α → m β) (b : β) (t : PersistentArray α) : m β :=
|
||||
do b ← mfoldlAux f t.root b, t.tail.mfoldl f b
|
||||
|
||||
end
|
||||
|
||||
@[inline] def foldl (f : β → α → β) (b : β) (t : PersistentArray α) : β :=
|
||||
Id.run (t.mfoldl f b)
|
||||
|
||||
def toList (t : PersistentArray α) : List α :=
|
||||
(t.foldl (λ xs x, x :: xs) []).reverse
|
||||
|
||||
section
|
||||
variables {m : Type v → Type v} [Monad m]
|
||||
|
||||
@[specialize] partial def mmapAux (f : α → m β) : PersistentArrayNode α → m (PersistentArrayNode β)
|
||||
| (node cs) := node <$> cs.mmap (λ c, mmapAux c)
|
||||
| (leaf vs) := leaf <$> vs.mmap f
|
||||
|
||||
@[specialize] def mmap (f : α → m β) (t : PersistentArray α) : m (PersistentArray β) :=
|
||||
do
|
||||
root ← mmapAux f t.root,
|
||||
tail ← t.tail.mmap f,
|
||||
pure { tail := tail, root := root, .. t }
|
||||
|
||||
end
|
||||
|
||||
@[inline] def map (f : α → β) (t : PersistentArray α) : PersistentArray β :=
|
||||
Id.run (t.mmap f)
|
||||
|
||||
structure Stats :=
|
||||
(numNodes : Nat) (depth : Nat) (tailSize : Nat)
|
||||
|
||||
partial def collectStats : PersistentArrayNode α → Stats → Nat → Stats
|
||||
| (node cs) s d :=
|
||||
cs.foldl (λ s c, collectStats c s (d+1))
|
||||
{ numNodes := s.numNodes + 1,
|
||||
depth := Nat.max d s.depth, .. s }
|
||||
| (leaf vs) s d := { numNodes := s.numNodes + 1, depth := Nat.max d s.depth, .. s }
|
||||
|
||||
def stats (r : PersistentArray α) : Stats :=
|
||||
collectStats r.root { numNodes := 0, depth := 0, tailSize := r.tail.size } 0
|
||||
|
||||
def Stats.toString (s : Stats) : String :=
|
||||
toString [s.numNodes, s.depth, s.tailSize]
|
||||
|
||||
instance : HasToString Stats := ⟨Stats.toString⟩
|
||||
|
||||
end PersistentArray
|
||||
|
||||
def List.toPersistentArrayAux {α : Type u} : List α → PersistentArray α → PersistentArray α
|
||||
| [] t := t
|
||||
| (x::xs) t := List.toPersistentArrayAux xs (t.push x)
|
||||
|
||||
def List.toPersistentArray {α : Type u} (xs : List α) : PersistentArray α :=
|
||||
xs.toPersistentArrayAux {}
|
||||
7
library/init/data/persistentarray/default.lean
Normal file
7
library/init/data/persistentarray/default.lean
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
/-
|
||||
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import init.data.persistentarray.basic
|
||||
|
|
@ -194,10 +194,6 @@ void object_compactor::operator()(object * o) {
|
|||
case object_kind::Task: r = insert_task(curr); break;
|
||||
case object_kind::Ref: r = insert_ref(curr); break;
|
||||
case object_kind::External: throw exception("external objects cannot be compacted");
|
||||
case object_kind::PArrayRoot:
|
||||
case object_kind::PArraySet:
|
||||
case object_kind::PArrayPush:
|
||||
case object_kind::PArrayPop: throw exception("persistent array objects cannot be compacted");
|
||||
}
|
||||
if (r) m_todo.pop_back();
|
||||
}
|
||||
|
|
@ -315,10 +311,6 @@ object * compacted_region::read() {
|
|||
case object_kind::Ref: fix_ref(curr); break;
|
||||
case object_kind::Task: lean_unreachable();
|
||||
case object_kind::External: lean_unreachable();
|
||||
case object_kind::PArrayRoot: lean_unreachable();
|
||||
case object_kind::PArraySet: lean_unreachable();
|
||||
case object_kind::PArrayPush: lean_unreachable();
|
||||
case object_kind::PArrayPop: lean_unreachable();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -57,18 +57,6 @@ obj_res set_io_error(obj_arg r, std::string const & msg) {
|
|||
return set_io_error(r, mk_io_user_error(mk_string(msg)));
|
||||
}
|
||||
|
||||
static obj_res option_of_io_result(obj_arg r) {
|
||||
if (io_result_is_ok(r)) {
|
||||
object * o = alloc_cnstr(1, 1, 0);
|
||||
cnstr_set(o, 0, io_result_get_value(r));
|
||||
dec(r);
|
||||
return o;
|
||||
} else {
|
||||
dec(r);
|
||||
return box(0);
|
||||
}
|
||||
}
|
||||
|
||||
static bool g_initializing = true;
|
||||
void io_mark_end_initialization() {
|
||||
g_initializing = false;
|
||||
|
|
|
|||
|
|
@ -56,10 +56,6 @@ size_t obj_byte_size(object * o) {
|
|||
case object_kind::MPZ: return sizeof(mpz_object);
|
||||
case object_kind::Thunk: return sizeof(thunk_object);
|
||||
case object_kind::Task: return sizeof(task_object);
|
||||
case object_kind::PArrayRoot: return sizeof(parray_object);
|
||||
case object_kind::PArraySet: return sizeof(parray_object);
|
||||
case object_kind::PArrayPush: return sizeof(parray_object);
|
||||
case object_kind::PArrayPop: return sizeof(parray_object);
|
||||
case object_kind::Ref: return sizeof(ref_object);
|
||||
case object_kind::External: lean_unreachable();
|
||||
}
|
||||
|
|
@ -76,10 +72,6 @@ size_t obj_header_size(object * o) {
|
|||
case object_kind::MPZ: return sizeof(mpz_object);
|
||||
case object_kind::Thunk: return sizeof(thunk_object);
|
||||
case object_kind::Task: return sizeof(task_object);
|
||||
case object_kind::PArrayRoot: return sizeof(parray_object);
|
||||
case object_kind::PArraySet: return sizeof(parray_object);
|
||||
case object_kind::PArrayPush: return sizeof(parray_object);
|
||||
case object_kind::PArrayPop: return sizeof(parray_object);
|
||||
case object_kind::Ref: return sizeof(ref_object);
|
||||
case object_kind::External: lean_unreachable();
|
||||
}
|
||||
|
|
@ -126,23 +118,6 @@ inline void dec(object * o, object* & todo) {
|
|||
|
||||
void deactivate_task(task_object * t);
|
||||
|
||||
static size_t parray_data_capacity(object ** data) {
|
||||
return reinterpret_cast<size_t*>(data)[-1];
|
||||
}
|
||||
|
||||
static object ** alloc_parray_data(size_t c) {
|
||||
size_t * mem = static_cast<size_t*>(malloc(sizeof(object*)*c + sizeof(size_t)));
|
||||
*mem = c;
|
||||
mem++;
|
||||
return reinterpret_cast<object**>(mem);
|
||||
}
|
||||
|
||||
static void dealloc_parray_data(object ** data) {
|
||||
size_t * mem = reinterpret_cast<size_t*>(data);
|
||||
mem--;
|
||||
free(mem);
|
||||
}
|
||||
|
||||
#ifdef LEAN_SMALL_ALLOCATOR
|
||||
static inline void free_heap_obj_core(object * o, size_t sz) {
|
||||
#else
|
||||
|
|
@ -214,10 +189,6 @@ static inline void free_task_obj(object * o) {
|
|||
FREE_OBJ(o, sizeof(task_object) + sizeof(rc_type));
|
||||
}
|
||||
|
||||
static inline void free_parray_obj(object * o) {
|
||||
FREE_OBJ(o, sizeof(parray_object) + sizeof(rc_type));
|
||||
}
|
||||
|
||||
static inline void free_external_obj(object * o) {
|
||||
FREE_OBJ(o, sizeof(external_object) + sizeof(rc_type));
|
||||
}
|
||||
|
|
@ -265,24 +236,6 @@ static void del_core(object * o, object * & todo) {
|
|||
case object_kind::Task:
|
||||
deactivate_task(to_task(o));
|
||||
break;
|
||||
case object_kind::PArrayPop:
|
||||
dec_ref(to_parray(o)->m_next, todo);
|
||||
free_parray_obj(o);
|
||||
break;
|
||||
case object_kind::PArrayPush:
|
||||
case object_kind::PArraySet:
|
||||
dec(to_parray(o)->m_elem, todo);
|
||||
dec_ref(to_parray(o)->m_next, todo);
|
||||
free_parray_obj(o);
|
||||
break;
|
||||
case object_kind::PArrayRoot: {
|
||||
object ** it = to_parray(o)->m_data;
|
||||
object ** end = it + to_parray(o)->m_size;
|
||||
for (; it != end; ++it) dec(*it, todo);
|
||||
dealloc_parray_data(to_parray(o)->m_data);
|
||||
free_parray_obj(o);
|
||||
break;
|
||||
}
|
||||
case object_kind::External:
|
||||
to_external(o)->m_class->m_finalize(to_external(o)->m_data);
|
||||
free_external_obj(o);
|
||||
|
|
@ -330,184 +283,6 @@ object * array_mk_empty() {
|
|||
return g_array_empty;
|
||||
}
|
||||
|
||||
// =======================================
|
||||
// Persistent arrays
|
||||
|
||||
static object ** parray_data_expand(object ** data, size_t sz) {
|
||||
size_t curr_capacity = parray_data_capacity(data);
|
||||
size_t new_capacity = curr_capacity == 0 ? 2 : (3 * curr_capacity + 1) >> 1;
|
||||
object ** new_data = alloc_parray_data(new_capacity);
|
||||
memcpy(new_data, data, sizeof(object*)*sz);
|
||||
dealloc_parray_data(data);
|
||||
return new_data;
|
||||
}
|
||||
|
||||
/* Given `c -> ... -> root`,
|
||||
revert links and make `c` to be the new root:
|
||||
`c <- ... <- root` */
|
||||
static void parray_reroot(object * c) {
|
||||
lean_assert(get_kind(c) != object_kind::PArrayRoot);
|
||||
parray_object * it = to_parray(c);
|
||||
parray_object * prev = nullptr;
|
||||
/* invert links */
|
||||
while (get_kind(it) != object_kind::PArrayRoot) {
|
||||
/* c <- ... <- prev, it -> it_next -> ... -> root
|
||||
c <- ... <- prev <- it, it_next -> ... -> root */
|
||||
parray_object * it_next = it->m_next;
|
||||
it->m_next = prev;
|
||||
prev = it;
|
||||
it = it_next;
|
||||
}
|
||||
lean_assert(prev != nullptr);
|
||||
lean_assert(get_kind(it) == object_kind::PArrayRoot);
|
||||
lean_assert(it != c);
|
||||
object * old_root = it;
|
||||
it->m_next = prev;
|
||||
object ** data = it->m_data;
|
||||
size_t sz = it->m_size;
|
||||
prev = it;
|
||||
it = it->m_next;
|
||||
/* move array to `c` */
|
||||
while (true) {
|
||||
lean_assert(prev != nullptr && it != prev);
|
||||
switch (get_kind(it)) {
|
||||
case object_kind::PArraySet:
|
||||
prev->m_kind = static_cast<unsigned>(object_kind::PArraySet);
|
||||
prev->m_idx = it->m_idx;
|
||||
prev->m_elem = data[it->m_idx];
|
||||
data[it->m_idx] = it->m_elem;
|
||||
break;
|
||||
case object_kind::PArrayPush:
|
||||
if (sz == parray_data_capacity(data))
|
||||
data = parray_data_expand(data, sz);
|
||||
prev->m_kind = static_cast<unsigned>(object_kind::PArrayPop);
|
||||
data[sz] = it->m_elem;
|
||||
sz++;
|
||||
break;
|
||||
case object_kind::PArrayPop:
|
||||
--sz;
|
||||
prev->m_kind = static_cast<unsigned>(object_kind::PArrayPush);
|
||||
prev->m_elem = data[sz];
|
||||
break;
|
||||
default:
|
||||
lean_unreachable();
|
||||
}
|
||||
if (it == c)
|
||||
break;
|
||||
prev = it;
|
||||
it = it->m_next;
|
||||
}
|
||||
lean_assert(it == c);
|
||||
it->m_kind = static_cast<unsigned>(object_kind::PArrayRoot);
|
||||
it->m_data = data;
|
||||
it->m_size = sz;
|
||||
dec_ref(old_root);
|
||||
inc_ref(c);
|
||||
}
|
||||
|
||||
static parray_object * move_parray_root(parray_object * src) {
|
||||
lean_assert(src->m_kind == static_cast<unsigned>(object_kind::PArrayRoot));
|
||||
lean_assert(get_rc(src) > 1);
|
||||
dec_ref(src);
|
||||
parray_object * r = new (alloc_heap_object(sizeof(parray_object))) parray_object();
|
||||
r->m_data = src->m_data;
|
||||
r->m_size = src->m_size;
|
||||
return r;
|
||||
}
|
||||
|
||||
obj_res alloc_parray(size_t capacity) {
|
||||
parray_object * r = new (alloc_heap_object(sizeof(parray_object))) parray_object();
|
||||
r->m_data = alloc_parray_data(capacity);
|
||||
r->m_size = 0;
|
||||
return r;
|
||||
}
|
||||
|
||||
size_t parray_size(b_obj_arg o) {
|
||||
if (get_kind(o) != object_kind::PArrayRoot)
|
||||
parray_reroot(o);
|
||||
return to_parray(o)->m_size;
|
||||
}
|
||||
|
||||
b_obj_res parray_get(b_obj_arg o, size_t i) {
|
||||
if (get_kind(o) != object_kind::PArrayRoot)
|
||||
parray_reroot(o);
|
||||
return to_parray(o)->m_data[i];
|
||||
}
|
||||
|
||||
obj_res parray_set(obj_arg o, size_t i, obj_arg v) {
|
||||
if (get_kind(o) != object_kind::PArrayRoot)
|
||||
parray_reroot(o);
|
||||
parray_object * p = to_parray(o);
|
||||
if (get_rc(p) == 1) {
|
||||
dec(p->m_data[i]);
|
||||
p->m_data[i] = v;
|
||||
return p;
|
||||
} else {
|
||||
parray_object * r = move_parray_root(p);
|
||||
p->m_kind = static_cast<unsigned>(object_kind::PArraySet);
|
||||
p->m_idx = i;
|
||||
p->m_elem = r->m_data[i];
|
||||
p->m_next = r;
|
||||
inc_ref(r);
|
||||
r->m_data[i] = v;
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
obj_res parray_push(obj_arg o, obj_arg v) {
|
||||
if (get_kind(o) != object_kind::PArrayRoot)
|
||||
parray_reroot(o);
|
||||
parray_object * p = to_parray(o);
|
||||
if (p->m_size == parray_data_capacity(p->m_data))
|
||||
p->m_data = parray_data_expand(p->m_data, p->m_size);
|
||||
if (get_rc(p) == 1) {
|
||||
p->m_data[p->m_size] = v;
|
||||
p->m_size++;
|
||||
return p;
|
||||
} else {
|
||||
parray_object * r = move_parray_root(p);
|
||||
p->m_kind = static_cast<unsigned>(object_kind::PArrayPop);
|
||||
p->m_next = r;
|
||||
inc_ref(r);
|
||||
r->m_data[r->m_size] = v;
|
||||
r->m_size++;
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
obj_res parray_pop(obj_arg o) {
|
||||
if (get_kind(o) != object_kind::PArrayRoot)
|
||||
parray_reroot(o);
|
||||
parray_object * p = to_parray(o);
|
||||
if (get_rc(p) == 1) {
|
||||
p->m_size--;
|
||||
dec(p->m_data[p->m_size]);
|
||||
return p;
|
||||
} else {
|
||||
parray_object * r = move_parray_root(p);
|
||||
r->m_size--;
|
||||
p->m_kind = static_cast<unsigned>(object_kind::PArrayPush);
|
||||
p->m_elem = r->m_data[r->m_size];
|
||||
p->m_next = r;
|
||||
inc_ref(r);
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
obj_res parray_copy(b_obj_arg o) {
|
||||
if (get_kind(o) != object_kind::PArrayRoot)
|
||||
parray_reroot(o);
|
||||
size_t sz = to_parray(o)->m_size;
|
||||
object ** data = to_parray(o)->m_data;
|
||||
parray_object * r = new (alloc_heap_object(sizeof(parray_object))) parray_object();
|
||||
r->m_size = sz;
|
||||
r->m_data = alloc_parray_data(parray_data_capacity(data));
|
||||
memcpy(r->m_data, data, sz);
|
||||
for (size_t i = 0; i < sz; i++)
|
||||
inc(data[i]);
|
||||
return r;
|
||||
}
|
||||
|
||||
// =======================================
|
||||
// Closures
|
||||
|
||||
|
|
@ -878,13 +653,6 @@ void mark_mt(object * o) {
|
|||
case object_kind::String:
|
||||
case object_kind::MPZ:
|
||||
return;
|
||||
case object_kind::PArrayPop:
|
||||
case object_kind::PArrayPush:
|
||||
case object_kind::PArraySet:
|
||||
case object_kind::PArrayRoot:
|
||||
/* `mark_mt` cannot be used with parray. They must be copied when used in multiple threads. */
|
||||
lean_unreachable();
|
||||
return;
|
||||
case object_kind::External: {
|
||||
object * fn = alloc_closure(reinterpret_cast<void*>(mark_mt_fn), 1, 0);
|
||||
to_external(o)->m_class->m_foreach(to_external(o)->m_data, fn);
|
||||
|
|
@ -1082,20 +850,6 @@ void mark_persistent(object * o) {
|
|||
case object_kind::String:
|
||||
case object_kind::MPZ:
|
||||
return;
|
||||
case object_kind::PArrayPop:
|
||||
mark_persistent(to_parray(o)->m_next);
|
||||
return;
|
||||
case object_kind::PArrayPush:
|
||||
case object_kind::PArraySet:
|
||||
mark_persistent(to_parray(o)->m_elem);
|
||||
mark_persistent(to_parray(o)->m_next);
|
||||
return;
|
||||
case object_kind::PArrayRoot: {
|
||||
object ** it = to_parray(o)->m_data;
|
||||
object ** end = it + to_parray(o)->m_size;
|
||||
for (; it != end; ++it) mark_persistent(*it);
|
||||
return;
|
||||
}
|
||||
case object_kind::External: {
|
||||
object * fn = alloc_closure(reinterpret_cast<void*>(mark_persistent_fn), 1, 0);
|
||||
to_external(o)->m_class->m_foreach(to_external(o)->m_data, fn);
|
||||
|
|
|
|||
|
|
@ -64,7 +64,6 @@ extern atomic<uint64> g_num_del;
|
|||
enum class object_memory_kind { MTHeap = 0, STHeap, Persistent, Stack, Region };
|
||||
|
||||
enum class object_kind { Constructor, Closure, Array, ScalarArray,
|
||||
PArrayRoot, PArraySet, PArrayPush, PArrayPop,
|
||||
String, MPZ, Thunk, Task, Ref, External };
|
||||
|
||||
/* Objects are initially allocated as STHeap. When we create a task, we change it to MTHeap. */
|
||||
|
|
@ -162,23 +161,6 @@ struct string_object : public object {
|
|||
object(object_kind::String, m), m_size(sz), m_capacity(c), m_length(len) {}
|
||||
};
|
||||
|
||||
/* Persistent arrays are implemented using 4 different kinds of cell:
|
||||
PArraySet, PArrayPush, PArrayPop and PArrayRoot. */
|
||||
struct parray_object : public object {
|
||||
parray_object * m_next; // PArraySet, PArrayPush, PArrayPop
|
||||
union {
|
||||
size_t m_idx; // PArraySet
|
||||
size_t m_size; // PArrayRoot
|
||||
};
|
||||
union {
|
||||
object ** m_data; // PArrayRoot
|
||||
object * m_elem; // PArrayPush and PArraySet
|
||||
};
|
||||
/* Remark: persistent arrays are single threaded object. The `mark_shared` operation
|
||||
copies it when the RC > 1 */
|
||||
parray_object():object(object_kind::PArrayRoot, object_memory_kind::STHeap) {}
|
||||
};
|
||||
|
||||
/* Note that `m_fun` is a pointer to a C function.
|
||||
The `apply` operator performs a cast operation. It casts m_fun to a C function pointer of the right arity.
|
||||
|
||||
|
|
@ -383,7 +365,6 @@ inline bool is_cnstr(object * o) { return get_kind(o) == object_kind::Constructo
|
|||
inline bool is_closure(object * o) { return get_kind(o) == object_kind::Closure; }
|
||||
inline bool is_array(object * o) { return get_kind(o) == object_kind::Array; }
|
||||
inline bool is_sarray(object * o) { return get_kind(o) == object_kind::ScalarArray; }
|
||||
inline bool is_parray(object * o) { auto k = get_kind(o); return k == object_kind::PArrayRoot || k == object_kind::PArraySet || k == object_kind::PArrayPush || k == object_kind::PArrayPop; }
|
||||
inline bool is_string(object * o) { return get_kind(o) == object_kind::String; }
|
||||
inline bool is_mpz(object * o) { return get_kind(o) == object_kind::MPZ; }
|
||||
inline bool is_thunk(object * o) { return get_kind(o) == object_kind::Thunk; }
|
||||
|
|
@ -398,7 +379,6 @@ inline constructor_object * to_cnstr(object * o) { lean_assert(is_cnstr(o)); ret
|
|||
inline closure_object * to_closure(object * o) { lean_assert(is_closure(o)); return static_cast<closure_object*>(o); }
|
||||
inline array_object * to_array(object * o) { lean_assert(is_array(o)); return static_cast<array_object*>(o); }
|
||||
inline sarray_object * to_sarray(object * o) { lean_assert(is_sarray(o)); return static_cast<sarray_object*>(o); }
|
||||
inline parray_object * to_parray(object * o) { lean_assert(is_parray(o)); return static_cast<parray_object*>(o); }
|
||||
inline string_object * to_string(object * o) { lean_assert(is_string(o)); return static_cast<string_object*>(o); }
|
||||
inline mpz_object * to_mpz(object * o) { lean_assert(is_mpz(o)); return static_cast<mpz_object*>(o); }
|
||||
inline thunk_object * to_thunk(object * o) { lean_assert(is_thunk(o)); return static_cast<thunk_object*>(o); }
|
||||
|
|
@ -700,19 +680,6 @@ inline void array_set(u_obj_arg o, size_t i, obj_arg v) {
|
|||
obj_set_data(o, sizeof(array_object) + sizeof(object*)*i, v); // NOLINT
|
||||
}
|
||||
|
||||
// =======================================
|
||||
// Persistent Array of objects
|
||||
|
||||
obj_res alloc_parray(size_t capacity);
|
||||
size_t parray_size(b_obj_arg o);
|
||||
b_obj_res parray_get(b_obj_arg o, size_t i);
|
||||
obj_res parray_set(obj_arg o, size_t i, obj_arg v);
|
||||
obj_res parray_push(obj_arg o, obj_arg v);
|
||||
obj_res parray_pop(obj_arg o);
|
||||
obj_res parray_copy(b_obj_arg o);
|
||||
|
||||
// =======================================
|
||||
|
||||
// =======================================
|
||||
// Array of scalars
|
||||
|
||||
|
|
|
|||
|
|
@ -389,45 +389,6 @@ void tst13() {
|
|||
std::cout << g_task7_counter << "\n";
|
||||
}
|
||||
|
||||
obj_res mk_parray(unsigned n, b_obj_arg v) {
|
||||
object * r = alloc_parray(n);
|
||||
for (unsigned i = 0; i < n; i++) {
|
||||
inc(v);
|
||||
r = parray_push(r, v);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
void tst14() {
|
||||
object * a = mk_parray(10, box(0));
|
||||
lean_assert(parray_size(a) == 10);
|
||||
lean_assert(parray_get(a, 0) == box(0));
|
||||
object * b = a;
|
||||
inc(b);
|
||||
lean_assert(get_rc(a) == 2);
|
||||
lean_assert(get_rc(b) == 2);
|
||||
a = parray_set(a, 0, box(1));
|
||||
lean_assert(a != b);
|
||||
lean_assert(get_rc(b) == 1);
|
||||
lean_assert(get_rc(a) == 2);
|
||||
lean_assert(parray_get(a, 0) == box(1));
|
||||
lean_assert(parray_get(a, 1) == box(0));
|
||||
lean_assert(parray_get(b, 0) == box(0));
|
||||
lean_assert(parray_get(a, 0) == box(1));
|
||||
inc(b);
|
||||
object * c = b;
|
||||
c = parray_push(c, box(20));
|
||||
lean_assert(parray_size(c) == 11);
|
||||
lean_assert(parray_size(a) == 10);
|
||||
lean_assert(parray_size(b) == 10);
|
||||
lean_assert(parray_get(c, 0) == box(0));
|
||||
lean_assert(parray_get(c, 10) == box(20));
|
||||
lean_assert(parray_get(a, 0) == box(1));
|
||||
lean_assert(parray_get(b, 0) == box(0));
|
||||
dec_ref(a); dec_ref(b);
|
||||
lean_assert(get_rc(c) == 1);
|
||||
dec_ref(c);
|
||||
}
|
||||
|
||||
obj_res mk_foo(unsigned n) {
|
||||
object * r = alloc_cnstr(0, 1, 0);
|
||||
|
|
@ -439,106 +400,6 @@ unsigned foo_val(b_obj_arg v) {
|
|||
return unbox(cnstr_get(v, 0));
|
||||
}
|
||||
|
||||
void tst15() {
|
||||
object * v1 = alloc_parray(0);
|
||||
v1 = parray_push(v1, mk_foo(2));
|
||||
v1 = parray_push(v1, mk_foo(3));
|
||||
lean_assert(foo_val(parray_get(v1, 0)) == 2);
|
||||
lean_assert(foo_val(parray_get(v1, 1)) == 3);
|
||||
object * v2 = v1;
|
||||
inc(v2);
|
||||
for (unsigned i = 0; i < 10; i++)
|
||||
v1 = parray_push(v1, mk_foo(i));
|
||||
v1 = parray_set(v1, 0, mk_foo(100));
|
||||
v1 = parray_set(v1, 1, mk_foo(100));
|
||||
lean_assert(parray_size(v2) == 2);
|
||||
lean_assert(foo_val(parray_get(v2, 0)) == 2);
|
||||
lean_assert(foo_val(parray_get(v2, 1)) == 3);
|
||||
object * v3 = v1;
|
||||
inc(v3);
|
||||
v1 = parray_pop(v1);
|
||||
v1 = parray_pop(v1);
|
||||
lean_assert(parray_size(v1) == 10);
|
||||
lean_assert(parray_size(v3) == 12);
|
||||
dec_ref(v1); dec_ref(v2);
|
||||
lean_assert(get_rc(v3) == 1);
|
||||
dec_ref(v3);
|
||||
}
|
||||
|
||||
void driver(unsigned max_sz, unsigned max_val, unsigned num_it,
|
||||
double push_freq,
|
||||
double pop_freq,
|
||||
double set_freq,
|
||||
double copy_freq) {
|
||||
object * v1 = alloc_parray(0);
|
||||
std::vector<unsigned> v2;
|
||||
std::mt19937 rng;
|
||||
rng.seed(static_cast<unsigned int>(time(0)));
|
||||
std::uniform_int_distribution<unsigned int> uint_dist;
|
||||
std::vector<std::pair<object*, std::vector<unsigned>>> copies;
|
||||
lean_assert(get_rc(v1) == 1);
|
||||
size_t acc_sz = 0;
|
||||
for (unsigned i = 0; i < num_it; i++) {
|
||||
acc_sz += parray_size(v1);
|
||||
double f = static_cast<double>(uint_dist(rng) % 10000) / 10000.0;
|
||||
if (f < copy_freq) {
|
||||
inc_ref(v1);
|
||||
copies.emplace_back(v1, v2);
|
||||
}
|
||||
f = static_cast<double>(uint_dist(rng) % 10000) / 10000.0;
|
||||
if (f < push_freq) {
|
||||
if (parray_size(v1) < max_sz) {
|
||||
unsigned a = uint_dist(rng) % max_val;
|
||||
v1 = parray_push(v1, box(a));
|
||||
v2.push_back(a);
|
||||
}
|
||||
}
|
||||
if (parray_size(v1) > 0) {
|
||||
f = static_cast<double>(uint_dist(rng) % 10000) / 10000.0;
|
||||
if (f < pop_freq) {
|
||||
v1 = parray_pop(v1);
|
||||
v2.pop_back();
|
||||
}
|
||||
}
|
||||
if (parray_size(v1) > 0) {
|
||||
f = static_cast<double>(uint_dist(rng) % 10000) / 10000.0;
|
||||
if (f < set_freq) {
|
||||
unsigned idx = uint_dist(rng) % parray_size(v1);
|
||||
unsigned a = uint_dist(rng) % max_val;
|
||||
v1 = parray_set(v1, idx, box(a));
|
||||
v2[idx] = a;
|
||||
}
|
||||
}
|
||||
f = static_cast<double>(uint_dist(rng) % 10000) / 10000.0;
|
||||
lean_assert(parray_size(v1) == v2.size());
|
||||
for (unsigned i = 0; i < v2.size(); i++) {
|
||||
lean_assert(unbox(parray_get(v1, i)) == v2[i]);
|
||||
}
|
||||
}
|
||||
for (std::pair<object *, std::vector<unsigned>> const & p : copies) {
|
||||
lean_assert(parray_size(p.first) == p.second.size());
|
||||
for (unsigned i = 0; i < p.second.size(); i++) {
|
||||
lean_assert(unbox(parray_get(p.first, i)) == p.second[i]);
|
||||
}
|
||||
dec_ref(p.first);
|
||||
}
|
||||
std::cout << "\n";
|
||||
std::cout << "Copies created: " << copies.size() << "\n";
|
||||
std::cout << "Average size: " << static_cast<double>(acc_sz) / static_cast<double>(num_it) << "\n";
|
||||
lean_assert(get_rc(v1) == 1);
|
||||
dec_ref(v1);
|
||||
}
|
||||
|
||||
void tst16() {
|
||||
driver(4, 32, 10000, 0.5, 0.1, 0.5, 0.1);
|
||||
driver(4, 32, 10000, 0.5, 0.1, 0.5, 0.1);
|
||||
driver(4, 32, 10000, 0.5, 0.1, 0.5, 0.5);
|
||||
driver(16, 16, 100000, 0.5, 0.5, 0.5, 0.01);
|
||||
driver(16, 16, 100000, 0.5, 0.1, 0.5, 0.01);
|
||||
driver(16, 16, 100000, 0.5, 0.6, 0.5, 0.01);
|
||||
driver(16, 16, 10000, 0.5, 0.1, 0.5, 0.0);
|
||||
}
|
||||
|
||||
object * mk_list(unsigned n) {
|
||||
object * r = box(0);
|
||||
for (unsigned i = 0; i < n; i++) {
|
||||
|
|
@ -878,9 +739,6 @@ int main() {
|
|||
tst11();
|
||||
tst12();
|
||||
tst13();
|
||||
tst14();
|
||||
tst15();
|
||||
tst16();
|
||||
// tst17(40000, 3000);
|
||||
tst17(400, 30);
|
||||
// tst18(4000, 3000);
|
||||
|
|
|
|||
28
tests/playground/persistentarray.lean
Normal file
28
tests/playground/persistentarray.lean
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
import init.data.persistentarray
|
||||
|
||||
abbrev MyArray := PersistentArray Nat
|
||||
-- abbrev MyArray := Array Nat
|
||||
|
||||
def mkMyArray (n : Nat) : MyArray :=
|
||||
n.fold (λ i s, s.push i) { PersistentArray . }
|
||||
-- n.fold (λ i s, s.push i) Array.empty
|
||||
|
||||
def check (n : Nat) (p : Nat → Nat → Bool) (s : MyArray) : IO Unit :=
|
||||
n.mfor $ λ i, unless (p i (s.get i)) (throw (IO.userError ("failed at " ++ toString i ++ " " ++ toString (s.get i))))
|
||||
|
||||
def inc1 (n : Nat) (s : MyArray) : MyArray :=
|
||||
n.fold (λ i s, s.set i (s.get i + 1)) s
|
||||
|
||||
def checkId (n : Nat) (s : MyArray) : IO Unit :=
|
||||
check n (==) s
|
||||
|
||||
def main (xs : List String) : IO Unit :=
|
||||
do
|
||||
let n := xs.head.toNat,
|
||||
let t := mkMyArray n,
|
||||
checkId n t,
|
||||
let t := inc1 n t,
|
||||
check n (λ i v, v == i + 1) t,
|
||||
IO.println t.size,
|
||||
IO.println t.stats,
|
||||
pure ()
|
||||
|
|
@ -1,200 +0,0 @@
|
|||
import init.lean.format
|
||||
open Lean
|
||||
universes u v w
|
||||
|
||||
inductive RadixNode (α : Type u)
|
||||
| node (cs : Array RadixNode) : RadixNode
|
||||
| leaf (vs : Array α) : RadixNode
|
||||
|
||||
instance RadixNode.inhabited {α : Type u} : Inhabited (RadixNode α) :=
|
||||
⟨RadixNode.leaf Array.empty⟩
|
||||
|
||||
abbrev RadixTree.initShift : USize := 5
|
||||
abbrev RadixTree.branching : USize := USize.ofNat (2 ^ RadixTree.initShift.toNat)
|
||||
|
||||
structure RadixTree (α : Type u) :=
|
||||
/- Recall that we run out of memory if we have more than `usizeSz/8` elements.
|
||||
So, we can stop adding elements at `root` after `size > usizeSz`, and
|
||||
keep growing the `tail`. This modification allow us to use `USize` instead
|
||||
of `Nat` when traversing `root`. -/
|
||||
(root : RadixNode α := RadixNode.node (Array.mkEmpty RadixTree.branching.toNat))
|
||||
(tail : Array α := Array.mkEmpty RadixTree.branching.toNat)
|
||||
(size : Nat := 0)
|
||||
(shift : USize := RadixTree.initShift)
|
||||
(tailOff : Nat := 0)
|
||||
|
||||
namespace RadixTree
|
||||
/- TODO:
|
||||
- Use proofs for showing that array accesses are not out of bounds.
|
||||
-/
|
||||
variables {α : Type u} {β : Type v}
|
||||
open RadixNode
|
||||
|
||||
instance : Inhabited (RadixTree α) := ⟨{}⟩
|
||||
|
||||
def mkEmptyArray : Array α := Array.mkEmpty branching.toNat
|
||||
|
||||
abbrev mul2Shift (i : USize) (shift : USize) : USize := USize.shift_left i shift
|
||||
abbrev div2Shift (i : USize) (shift : USize) : USize := USize.shift_right i shift
|
||||
abbrev mod2Shift (i : USize) (shift : USize) : USize := USize.land i ((USize.shift_left 1 shift) - 1)
|
||||
|
||||
partial def getAux [Inhabited α] : RadixNode α → USize → USize → α
|
||||
| (node cs) i shift := getAux (cs.get (div2Shift i shift).toNat) (mod2Shift i shift) (shift - initShift)
|
||||
| (leaf cs) i _ := cs.get i.toNat
|
||||
|
||||
def get [Inhabited α] (t : RadixTree α) (i : Nat) : α :=
|
||||
if i >= t.tailOff then
|
||||
t.tail.get (i - t.tailOff)
|
||||
else
|
||||
getAux t.root (USize.ofNat i) t.shift
|
||||
|
||||
partial def setAux : RadixNode α → USize → USize → α → RadixNode α
|
||||
| (node cs) i shift a := node (cs.modify (div2Shift i shift).toNat $ λ c,
|
||||
setAux c (mod2Shift i shift) (shift - initShift) a)
|
||||
| (leaf cs) i _ a := leaf (cs.set i.toNat a)
|
||||
|
||||
def set (t : RadixTree α) (i : Nat) (a : α) : RadixTree α :=
|
||||
if i >= t.tailOff then
|
||||
{ tail := t.tail.set (i - t.tailOff) a, .. t }
|
||||
else
|
||||
{ root := setAux t.root (USize.ofNat i) t.shift a, .. t }
|
||||
|
||||
partial def mkNewPath : USize → Array α → RadixNode α
|
||||
| shift a :=
|
||||
if shift == 0 then
|
||||
leaf a
|
||||
else
|
||||
node (mkEmptyArray.push (mkNewPath (shift - initShift) a))
|
||||
|
||||
partial def insertNewLeaf : RadixNode α → USize → USize → Array α → RadixNode α
|
||||
| (node cs) i shift a :=
|
||||
if i < branching then
|
||||
node (cs.push (leaf a))
|
||||
else
|
||||
let j := div2Shift i shift in
|
||||
if j.toNat < cs.size then
|
||||
node (cs.modify j.toNat $ λ c, insertNewLeaf c (mod2Shift i shift) (shift - initShift) a)
|
||||
else
|
||||
node (cs.push (mkNewPath (shift - initShift) a))
|
||||
| n _ _ _ := n -- unreachable
|
||||
|
||||
def mkNewTail (t : RadixTree α) : RadixTree α :=
|
||||
if t.size <= (mul2Shift 1 (t.shift + initShift)).toNat then
|
||||
{ tail := mkEmptyArray, root := insertNewLeaf t.root (USize.ofNat (t.size - 1)) t.shift t.tail,
|
||||
tailOff := t.size,
|
||||
.. t }
|
||||
else
|
||||
{ tail := Array.empty,
|
||||
root := let n := mkEmptyArray.push t.root in
|
||||
node (n.push (mkNewPath t.shift t.tail)),
|
||||
shift := t.shift + initShift,
|
||||
tailOff := t.size,
|
||||
.. t }
|
||||
|
||||
def tooBig : Nat := usizeSz / 8
|
||||
|
||||
def push (t : RadixTree α) (a : α) : RadixTree α :=
|
||||
let r := { tail := t.tail.push a, size := t.size + 1, .. t } in
|
||||
if r.tail.size < branching.toNat || t.size >= tooBig then
|
||||
r
|
||||
else
|
||||
mkNewTail r
|
||||
|
||||
section
|
||||
variables {m : Type v → Type v} [Monad m]
|
||||
local attribute [instance] monadInhabited'
|
||||
|
||||
@[specialize] partial def mfoldlAux (f : β → α → m β) : RadixNode α → β → m β
|
||||
| (node cs) b := cs.mfoldl (λ b c, mfoldlAux c b) b
|
||||
| (leaf vs) b := vs.mfoldl f b
|
||||
|
||||
@[specialize] def mfoldl (f : β → α → m β) (b : β) (t : RadixTree α) : m β :=
|
||||
do b ← mfoldlAux f t.root b, t.tail.mfoldl f b
|
||||
|
||||
end
|
||||
|
||||
@[inline] def foldl (f : β → α → β) (b : β) (t : RadixTree α) : β :=
|
||||
Id.run (t.mfoldl f b)
|
||||
|
||||
def toList (t : RadixTree α) : List α :=
|
||||
(t.foldl (λ xs x, x :: xs) []).reverse
|
||||
|
||||
section
|
||||
variables {m : Type v → Type v} [Monad m]
|
||||
|
||||
@[specialize] partial def mmapAux (f : α → m β) : RadixNode α → m (RadixNode β)
|
||||
| (node cs) := node <$> cs.mmap (λ c, mmapAux c)
|
||||
| (leaf vs) := leaf <$> vs.mmap f
|
||||
|
||||
@[specialize] def mmap (f : α → m β) (t : RadixTree α) : m (RadixTree β) :=
|
||||
do
|
||||
root ← mmapAux f t.root,
|
||||
tail ← t.tail.mmap f,
|
||||
pure { tail := tail, root := root, .. t }
|
||||
|
||||
end
|
||||
|
||||
@[inline] def map (f : α → β) (t : RadixTree α) : RadixTree β :=
|
||||
Id.run (t.mmap f)
|
||||
|
||||
structure Stats :=
|
||||
(numNodes : Nat) (depth : Nat) (tailSize : Nat)
|
||||
|
||||
partial def collectStats : RadixNode α → Stats → Nat → Stats
|
||||
| (node cs) s d :=
|
||||
cs.foldl (λ s c, collectStats c s (d+1))
|
||||
{ numNodes := s.numNodes + 1,
|
||||
depth := Nat.max d s.depth, .. s }
|
||||
| (leaf vs) s d := { numNodes := s.numNodes + 1, depth := Nat.max d s.depth, .. s }
|
||||
|
||||
def stats (r : RadixTree α) : Stats :=
|
||||
collectStats r.root { numNodes := 0, depth := 0, tailSize := r.tail.size } 0
|
||||
|
||||
def Stats.toString (s : Stats) : String :=
|
||||
toString [s.numNodes, s.depth, s.tailSize]
|
||||
|
||||
instance : HasToString Stats := ⟨Stats.toString⟩
|
||||
|
||||
partial def formatRawAux [HasFormat α] : RadixNode α → Format
|
||||
| (node cs) := "Node" ++ Format.sbracket (cs.foldl (λ f c, f ++ Format.line ++ formatRawAux c) Format.nil)
|
||||
| (leaf cs) := format cs.toList
|
||||
|
||||
partial def formatRaw [HasFormat α] (t : RadixTree α) : Format :=
|
||||
Format.bracket "{" ("root :=" ++ Format.line ++ formatRawAux t.root ++ "," ++ Format.line ++
|
||||
"tail :=" ++ Format.line ++ format t.tail.toList) "}"
|
||||
end RadixTree
|
||||
|
||||
def List.toRadixTreeAux {α : Type u} : List α → RadixTree α → RadixTree α
|
||||
| [] t := t
|
||||
| (x::xs) t := List.toRadixTreeAux xs (t.push x)
|
||||
|
||||
def List.toRadixTree {α : Type u} (xs : List α) : RadixTree α :=
|
||||
xs.toRadixTreeAux {}
|
||||
|
||||
abbrev PArray := RadixTree Nat
|
||||
-- abbrev PArray := Array Nat
|
||||
|
||||
def mkRadixTree (n : Nat) : PArray :=
|
||||
n.fold (λ i s, s.push i) { RadixTree . }
|
||||
-- n.fold (λ i s, s.push i) Array.empty
|
||||
|
||||
def check (n : Nat) (p : Nat → Nat → Bool) (s : PArray) : IO Unit :=
|
||||
n.mfor $ λ i, unless (p i (s.get i)) (throw (IO.userError ("failed at " ++ toString i ++ " " ++ toString (s.get i))))
|
||||
|
||||
def inc1 (n : Nat) (s : PArray) : PArray :=
|
||||
n.fold (λ i s, s.set i (s.get i + 1)) s
|
||||
|
||||
def checkId (n : Nat) (s : PArray) : IO Unit :=
|
||||
check n (==) s
|
||||
|
||||
def main (xs : List String) : IO Unit :=
|
||||
do
|
||||
let n := xs.head.toNat,
|
||||
let t := mkRadixTree n,
|
||||
-- IO.println t.formatRaw *>
|
||||
checkId n t,
|
||||
let t := inc1 n t,
|
||||
check n (λ i v, v == i + 1) t,
|
||||
IO.println t.size,
|
||||
IO.println t.stats,
|
||||
pure ()
|
||||
Loading…
Add table
Reference in a new issue