/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import init.data.array
universes u v w

/- Tree node of a `PersistentArray`: interior nodes store child subtrees in `cs`,
   leaves store elements in `vs` (leaves are created from a full `tail`). -/
inductive PersistentArrayNode (α : Type u)
| node (cs : Array PersistentArrayNode) : PersistentArrayNode
| leaf (vs : Array α) : PersistentArrayNode

instance PersistentArrayNode.inhabited {α : Type u} : Inhabited (PersistentArrayNode α) :=
⟨PersistentArrayNode.leaf Array.empty⟩

/- Number of index bits consumed per tree level. -/
abbrev PersistentArray.initShift : USize := 5
/- Fan-out of the tree: `2 ^ initShift`. -/
abbrev PersistentArray.branching : USize := USize.ofNat (2 ^ PersistentArray.initShift.toNat)

/- A persistent array: elements with index `< tailOff` live in the tree `root`,
   the remaining ones in the `tail` buffer. `shift` is the number of index bits
   still to be consumed below the root level (it decreases by `initShift` per level). -/
structure PersistentArray (α : Type u) :=
/- Recall that we run out of memory if we have more than `usizeSz/8` elements.
   So, we can stop adding elements at `root` once `size >= usizeSz/8` (see `tooBig`),
   and keep growing the `tail`. This modification allows us to use `USize` instead
   of `Nat` when traversing `root`. -/
(root    : PersistentArrayNode α := PersistentArrayNode.node (Array.mkEmpty PersistentArray.branching.toNat))
(tail    : Array α := Array.mkEmpty PersistentArray.branching.toNat)
(size    : Nat := 0)
(shift   : USize := PersistentArray.initShift)
(tailOff : Nat := 0)

abbrev PArray (α : Type u) := PersistentArray α

namespace PersistentArray
/- TODO: use proofs for showing that array accesses are not out of bounds.
   We can do it after we reimplement the tactic framework. -/
variables {α : Type u}
open PersistentArrayNode

/- The empty persistent array (every field at its default). -/
def empty : PersistentArray α :=
{}

def isEmpty (a : PersistentArray α) : Bool :=
a.size == 0

instance : Inhabited (PersistentArray α) :=
⟨{}⟩

/- An empty `Array` with capacity `branching`. -/
def mkEmptyArray : Array α :=
Array.mkEmpty branching.toNat

/- `i` shifted left by `shift` bits. -/
abbrev mul2Shift (i : USize) (shift : USize) : USize :=
USize.shift_left i shift
/- `i` shifted right by `shift` bits (child index at the current level). -/
abbrev div2Shift (i : USize) (shift : USize) : USize :=
USize.shift_right i shift
/- The low `shift` bits of `i` (index relative to the selected child). -/
abbrev mod2Shift (i : USize) (shift : USize) : USize :=
USize.land i ((USize.shift_left 1 shift) - 1)

/- Descends the tree: at each `node` the high bits of `i` select the child,
   the low bits become the index inside that child. -/
partial def getAux [Inhabited α] : PersistentArrayNode α → USize → USize → α
| (node cs) i shift := getAux (cs.get (div2Shift i shift).toNat) (mod2Shift i shift) (shift - initShift)
| (leaf cs) i _     := cs.get i.toNat

/- Element at index `i`: indices `>= tailOff` are read from `tail`, the others from the tree. -/
def get [Inhabited α] (t : PersistentArray α) (i : Nat) : α :=
if i >= t.tailOff then
  t.tail.get (i - t.tailOff)
else
  getAux t.root (USize.ofNat i) t.shift

/- Path-copying update of the element at (tree-relative) index `i`. -/
partial def setAux : PersistentArrayNode α → USize → USize → α → PersistentArrayNode α
| (node cs) i shift a :=
  let j     := div2Shift i shift;
  let i     := mod2Shift i shift;
  let shift := shift - initShift;
  node $ cs.modify j.toNat $ fun c => setAux c i shift a
| (leaf cs) i _ a := leaf (cs.set i.toNat a)

/- Replaces the element at index `i` with `a`, either in `tail` or in the tree. -/
def set (t : PersistentArray α) (i : Nat) (a : α) : PersistentArray α :=
if i >= t.tailOff then
  { tail := t.tail.set (i - t.tailOff) a, .. t }
else
  { root := setAux t.root (USize.ofNat i) t.shift a, .. t }

/- Path-copying application of `f` to the element at (tree-relative) index `i`. -/
@[specialize] partial def modifyAux [Inhabited α] (f : α → α) : PersistentArrayNode α → USize → USize → PersistentArrayNode α
| (node cs) i shift :=
  let j     := div2Shift i shift;
  let i     := mod2Shift i shift;
  let shift := shift - initShift;
  node $ cs.modify j.toNat $ fun c => modifyAux c i shift
| (leaf cs) i _ := leaf (cs.modify i.toNat f)

/- Applies `f` to the element at index `i`, either in `tail` or in the tree. -/
@[specialize] def modify [Inhabited α] (t : PersistentArray α) (i : Nat) (f : α → α) : PersistentArray α :=
if i >= t.tailOff then
  { tail := t.tail.modify (i - t.tailOff) f, .. t }
else
  { root := modifyAux f t.root (USize.ofNat i) t.shift, .. t }

/- Builds a chain of single-child `node`s of height `shift / initShift` ending in `leaf a`. -/
partial def mkNewPath : USize → Array α → PersistentArrayNode α
| shift a :=
  if shift == 0 then
    leaf a
  else
    node (mkEmptyArray.push (mkNewPath (shift - initShift) a))

/- Appends the leaf `a` to the subtree; `i` is the subtree-relative index of the
   last element stored in `a`. -/
partial def insertNewLeaf : PersistentArrayNode α → USize → USize → Array α → PersistentArrayNode α
| (node cs) i shift a :=
  if i < branching then
    -- the new leaf holds the first `branching` elements of this subtree (e.g. the root is still empty)
    node (cs.push (leaf a))
  else
    let j     := div2Shift i shift;
    let i     := mod2Shift i shift;
    let shift := shift - initShift;
    if j.toNat < cs.size then
       node $ cs.modify j.toNat $ fun c => insertNewLeaf c i shift a
    else
       -- no child covers index `i` yet: grow a fresh path down to the new leaf
       node $ cs.push $ mkNewPath shift a
| n _ _ _ := n -- unreachable

/- Moves the full `tail` into the tree. If the tree still has room at its current
   height (`size <= 2^(shift + initShift)`), the tail becomes a new leaf; otherwise a
   new root one level higher is created with the old root as its first child. -/
def mkNewTail (t : PersistentArray α) : PersistentArray α :=
if t.size <= (mul2Shift 1 (t.shift + initShift)).toNat then
  { tail    := mkEmptyArray,
    root    := insertNewLeaf t.root (USize.ofNat (t.size - 1)) t.shift t.tail,
    tailOff := t.size,
    .. t }
else
  { tail    := mkEmptyArray,
    root    := let n := mkEmptyArray.push t.root;
               node (n.push (mkNewPath t.shift t.tail)),
    shift   := t.shift + initShift,
    tailOff := t.size,
    .. t }

/- Once `size` reaches this bound, `push` stops flushing the tail into the tree. -/
def tooBig : Nat := usizeSz / 8

/- Appends `a`. The tail is flushed into the tree when it reaches `branching`
   elements, unless the array is already `tooBig`. -/
def push (t : PersistentArray α) (a : α) : PersistentArray α :=
let r := { tail := t.tail.push a, size := t.size + 1, .. t };
if r.tail.size < branching.toNat || t.size >= tooBig then
  r
else
  mkNewTail r

private def emptyArray {α : Type u} : Array (PersistentArrayNode α) :=
Array.mkEmpty PersistentArray.branching.toNat

/- Removes the right-most leaf below a node. Returns `(some vs, cs')`, where `vs` are
   the elements of the removed leaf and `cs'` the children this node keeps (empty
   children are dropped), or `(none, emptyArray)` if the right-most path ends in an
   empty `node`. -/
partial def popLeaf : PersistentArrayNode α → Option (Array α) × Array (PersistentArrayNode α)
| (node cs) :=
  if h : cs.size ≠ 0 then
    let idx : Fin cs.size := ⟨cs.size - 1, Nat.predLt h⟩;
    let last := cs.fget idx;
    match popLeaf last with
    | (none, _)         => (none, emptyArray)
    | (some l, newLast) =>
      if newLast.size == 0 then
        let cs := cs.pop;
        if cs.isEmpty then (some l, emptyArray) else (some l, cs)
      else
        (some l, cs.fset idx (node newLast))
  else
    (none, emptyArray)
| (leaf vs) := (some vs, emptyArray)

/- Removes the last element. If the tail is empty, the right-most leaf is pulled
   out of the tree and becomes the new tail (minus its last element). -/
def pop (t : PersistentArray α) : PersistentArray α :=
if t.tail.size > 0 then
  { tail := t.tail.pop, size := t.size - 1, .. t }
else
  match popLeaf t.root with
  | (none, _) => t
  | (some last, newRoots) =>
    let last       := last.pop;
    let newSize    := t.size - 1;
    let newTailOff := newSize - last.size;
    -- Collapse a single-child root only when that child is an interior `node`.
    -- When `t.shift == initShift` the root's children are leaves; promoting a leaf
    -- to root (with shift 0) would let a later `push` build a `node` root at shift 0,
    -- which `getAux` then indexes with the full element index.
    if newRoots.size == 1 && initShift < t.shift then
      { root    := newRoots.get 0,
        shift   := t.shift - initShift,
        size    := newSize,
        tail    := last,
        tailOff := newTailOff }
    else
      { root    := node newRoots,
        size    := newSize,
        tail    := last,
        tailOff := newTailOff,
        .. t }

section
variables {m : Type v → Type w} [Monad m]
variable {β : Type v}

@[specialize] partial def mfoldlAux (f : β → α → m β) : PersistentArrayNode α → β → m β
| (node cs) b := cs.mfoldl (fun b c => mfoldlAux c b) b
| (leaf vs) b := vs.mfoldl f b

/- Monadic left fold over all elements in index order: the tree first, then the tail. -/
@[specialize] def mfoldl (t : PersistentArray α) (f : β → α → m β) (b : β) : m β :=
do b ← mfoldlAux f t.root b; t.tail.mfoldl f b

@[specialize] partial def mfindAux (f : α → m (Option β)) : PersistentArrayNode α → m (Option β)
| (node cs) := cs.mfind (fun c => mfindAux c)
| (leaf vs) := vs.mfind f

/- Returns the first `some` produced by `f`, searching the tree first, then the tail. -/
@[specialize] def mfind (t : PersistentArray α) (f : α → m (Option β)) : m (Option β) :=
do b ← mfindAux f t.root;
   match b with
   | none   => t.tail.mfind f
   | some b => pure (some b)

@[specialize] partial def mfindRevAux (f : α → m (Option β)) : PersistentArrayNode α → m (Option β)
| (node cs) := cs.mfindRev (fun c => mfindRevAux c)
| (leaf vs) := vs.mfindRev f

/- Like `mfind`, but searches the tail first, then the tree, using `Array.mfindRev`. -/
@[specialize] def mfindRev (t : PersistentArray α) (f : α → m (Option β)) : m (Option β) :=
do b ← t.tail.mfindRev f;
   match b with
   | none   => mfindRevAux f t.root
   | some b => pure (some b)
end

@[inline] def foldl {β} (t : PersistentArray α) (f : β → α → β) (b : β) : β :=
Id.run (t.mfoldl f b)

@[inline] def find {β} (t : PersistentArray α) (f : α → (Option β)) : Option β :=
Id.run (t.mfind f)

@[inline] def findRev {β} (t : PersistentArray α) (f : α → (Option β)) : Option β :=
Id.run (t.mfindRev f)

/- All elements in index order. -/
def toList (t : PersistentArray α) : List α :=
(t.foldl (fun xs x => x :: xs) []).reverse

section
variables {m : Type u → Type v} [Monad m]
variable {β : Type u}

@[specialize] partial def mmapAux (f : α → m β) : PersistentArrayNode α → m (PersistentArrayNode β)
| (node cs) := node <$> cs.mmap (fun c => mmapAux c)
| (leaf vs) := leaf <$> vs.mmap f

/- Maps `f` over every element, keeping the tree shape, `size`, `shift` and `tailOff`. -/
@[specialize] def mmap (f : α → m β) (t : PersistentArray α) : m (PersistentArray β) :=
do
  root ← mmapAux f t.root;
  tail ← t.tail.mmap f;
  pure { tail := tail, root := root, .. t }
end

@[inline] def map {β} (f : α → β) (t : PersistentArray α) : PersistentArray β :=
Id.run (t.mmap f)

/- Shape statistics of a persistent array (for debugging). -/
structure Stats :=
(numNodes : Nat) (depth : Nat) (tailSize : Nat)

/- Counts nodes (interior and leaf) and records the maximum depth reached. -/
partial def collectStats : PersistentArrayNode α → Stats → Nat → Stats
| (node cs) s d :=
  cs.foldl (fun s c => collectStats c s (d+1))
    { numNodes := s.numNodes + 1,
      depth    := Nat.max d s.depth, .. s }
| (leaf vs) s d := { numNodes := s.numNodes + 1, depth := Nat.max d s.depth, .. s }

def stats (r : PersistentArray α) : Stats :=
collectStats r.root { numNodes := 0, depth := 0, tailSize := r.tail.size } 0

def Stats.toString (s : Stats) : String :=
"{nodes := " ++ toString s.numNodes ++ ", depth := " ++ toString s.depth ++ ", tail size := " ++ toString s.tailSize ++ "}"

instance : HasToString Stats := ⟨Stats.toString⟩

end PersistentArray

def List.toPersistentArrayAux {α : Type u} : List α → PersistentArray α → PersistentArray α
| []      t := t
| (x::xs) t := List.toPersistentArrayAux xs (t.push x)

/- Builds a persistent array containing the list's elements in order. -/
def List.toPersistentArray {α : Type u} (xs : List α) : PersistentArray α :=
xs.toPersistentArrayAux {}