diff --git a/tests/playground/radixtree.lean b/tests/playground/radixtree.lean index afe41ebd21..8f7e949f6c 100644 --- a/tests/playground/radixtree.lean +++ b/tests/playground/radixtree.lean @@ -9,7 +9,8 @@ inductive RadixNode (α : Type u) instance RadixNode.inhabited {α : Type u} : Inhabited (RadixNode α) := ⟨RadixNode.leaf Array.empty⟩ -abbrev RadixTree.branching : USize := USize.ofNat 32 +abbrev RadixTree.initShift : USize := 5 +abbrev RadixTree.branching : USize := USize.ofNat (2 ^ RadixTree.initShift.toNat) structure RadixTree (α : Type u) := /- Recall that we run out of memory if we have more than `usizeSz/8` elements. @@ -19,13 +20,12 @@ structure RadixTree (α : Type u) := (root : RadixNode α := RadixNode.node (Array.mkEmpty RadixTree.branching.toNat)) (tail : Array α := Array.mkEmpty RadixTree.branching.toNat) (size : Nat := 0) -(mh : USize := RadixTree.branching) +(shift : USize := RadixTree.initShift) (tailOff : Nat := 0) namespace RadixTree /- TODO: - Use proofs for showing that array accesses are not out of bounds. - - Use bit shifting and masking operations instead of `/` and `%`. -/ variables {α : Type u} {β : Type v} open RadixNode @@ -34,55 +34,60 @@ instance : Inhabited (RadixTree α) := ⟨{}⟩ def mkEmptyArray : Array α := Array.mkEmpty branching.toNat +abbrev mul2Shift (i : USize) (shift : USize) : USize := USize.shift_left i shift +abbrev div2Shift (i : USize) (shift : USize) : USize := USize.shift_right i shift +abbrev mod2Shift (i : USize) (shift : USize) : USize := USize.land i ((USize.shift_left 1 shift) - 1) + partial def getAux [Inhabited α] : RadixNode α → USize → USize → α -| (node cs) i mh := getAux (cs.get (i / mh).toNat) (i % mh) (mh / branching) -| (leaf cs) i _ := cs.get i.toNat +| (node cs) i shift := getAux (cs.get (div2Shift i shift).toNat) (mod2Shift i shift) (shift - initShift) +| (leaf cs) i _ := cs.get i.toNat def get [Inhabited α] (t : RadixTree α) (i : Nat) : α := if i >= t.tailOff then t.tail.get (i - t.tailOff) else - getAux t.root (USize.ofNat i) t.mh + getAux t.root (USize.ofNat i) t.shift partial def setAux : RadixNode α → USize → USize → α → RadixNode α -| (node cs) i mh a := node (cs.modify (i / mh).toNat $ λ c, setAux c (i % mh) (mh / branching) a) -| (leaf cs) i _ a := leaf (cs.set i.toNat a) +| (node cs) i shift a := node (cs.modify (div2Shift i shift).toNat $ λ c, + setAux c (mod2Shift i shift) (shift - initShift) a) +| (leaf cs) i _ a := leaf (cs.set i.toNat a) def set (t : RadixTree α) (i : Nat) (a : α) : RadixTree α := if i >= t.tailOff then { tail := t.tail.set (i - t.tailOff) a, .. t } else - { root := setAux t.root (USize.ofNat i) t.mh a, .. t } + { root := setAux t.root (USize.ofNat i) t.shift a, .. t } partial def mkNewPath : USize → Array α → RadixNode α -| mh a := - if mh <= 1 then +| shift a := + if shift == 0 then leaf a else - node (mkEmptyArray.push (mkNewPath (mh / branching) a)) + node (mkEmptyArray.push (mkNewPath (shift - initShift) a)) partial def insertNewLeaf : RadixNode α → USize → USize → Array α → RadixNode α -| (node cs) i mh a := +| (node cs) i shift a := if i < branching then node (cs.push (leaf a)) else - let j := i / mh in + let j := div2Shift i shift in if j.toNat < cs.size then - node (cs.modify j.toNat $ λ c, insertNewLeaf c (i % mh) (mh / branching) a) + node (cs.modify j.toNat $ λ c, insertNewLeaf c (mod2Shift i shift) (shift - initShift) a) else - node (cs.push (mkNewPath (mh / branching) a)) -| n _ _ _ := n -- unreachable + node (cs.push (mkNewPath (shift - initShift) a)) +| n _ _ _ := n -- unreachable def mkNewTail (t : RadixTree α) : RadixTree α := -if t.size <= (t.mh * branching).toNat then - { tail := mkEmptyArray, root := insertNewLeaf t.root (USize.ofNat (t.size - 1)) t.mh t.tail, +if t.size <= (mul2Shift 1 (t.shift + initShift)).toNat then + { tail := mkEmptyArray, root := insertNewLeaf t.root (USize.ofNat (t.size - 1)) t.shift t.tail, tailOff := t.size, .. t } else { tail := Array.empty, root := let n := mkEmptyArray.push t.root in - node (n.push (mkNewPath t.mh t.tail)), - mh := t.mh * branching, + node (n.push (mkNewPath t.shift t.tail)), + shift := t.shift + initShift, tailOff := t.size, .. t }