lean4-htt/library/init/data/persistentarray/basic.lean
2019-06-24 15:48:11 -07:00

189 lines
6.8 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import init.data.array
universes u v w
/- Tree node used as the `root` of a `PersistentArray`.
   A node is either an interior `node` carrying an array of child nodes,
   or a `leaf` carrying an array of element values. -/
inductive PersistentArrayNode (α : Type u)
-- Interior node: `cs` holds the child subtrees.
| node (cs : Array PersistentArrayNode) : PersistentArrayNode
-- Leaf node: `vs` holds the stored values.
| leaf (vs : Array α) : PersistentArrayNode
/- Default `PersistentArrayNode`: a leaf that holds the empty array. -/
instance PersistentArrayNode.inhabited {α : Type u} : Inhabited (PersistentArrayNode α) :=
Inhabited.mk (PersistentArrayNode.leaf Array.empty)
/- Step by which `shift` changes from one tree level to the next; it is also the
   default `shift` of an empty `PersistentArray` and the exponent used by `branching`. -/
abbrev PersistentArray.initShift : USize := 5
/- `2 ^ initShift` (i.e. 32 for `initShift = 5`), converted to `USize`.
   Used as the initial capacity of node/tail arrays and as the tail size at which
   `push` moves the tail into the tree. -/
abbrev PersistentArray.branching : USize := USize.ofNat (2 ^ PersistentArray.initShift.toNat)
/- Persistent array made of a tree (`root`) plus a flat `tail` array.
   Indices `>= tailOff` are served from `tail`; smaller indices are looked up in `root`. -/
structure PersistentArray (α : Type u) :=
/- Recall that we run out of memory if we have more than `usizeSz/8` elements.
So, we can stop adding elements at `root` once `size` reaches `usizeSz/8`
(see `tooBig` and `push`), and keep growing the `tail`. This modification
allows us to use `USize` instead of `Nat` when traversing `root`. -/
-- Tree part; by default an interior node with no children (capacity `branching`).
(root : PersistentArrayNode α := PersistentArrayNode.node (Array.mkEmpty PersistentArray.branching.toNat))
-- Elements stored outside the tree; `push` appends here.
(tail : Array α := Array.mkEmpty PersistentArray.branching.toNat)
-- Total number of elements; `push` increments it by one.
(size : Nat := 0)
-- Shift amount handed to the tree traversals when they start at `root`.
(shift : USize := PersistentArray.initShift)
-- Index of the first element stored in `tail`.
(tailOff : Nat := 0)
/- Short alias for `PersistentArray`. -/
abbrev PArray (α : Type u) := PersistentArray α
namespace PersistentArray
/- TODO: use proofs for showing that array accesses are not out of bounds.
We can do it after we reimplement the tactic framework. -/
variables {α : Type u} {β : Type v}
open PersistentArrayNode
/- Default `PersistentArray`: the structure built from all field defaults. -/
instance : Inhabited (PersistentArray α) :=
Inhabited.mk {}
/- Empty array created with `Array.mkEmpty`, passing `branching` (as a `Nat`) as its argument. -/
def mkEmptyArray : Array α :=
Array.mkEmpty (USize.toNat branching)
/- `i` shifted left by `shift` (`USize.shift_left`). -/
abbrev mul2Shift (i : USize) (shift : USize) : USize :=
i.shift_left shift
/- `i` shifted right by `shift` (`USize.shift_right`). -/
abbrev div2Shift (i : USize) (shift : USize) : USize :=
i.shift_right shift
/- Bitwise-and (`USize.land`) of `i` with the mask `(1 shifted left by shift) - 1`. -/
abbrev mod2Shift (i : USize) (shift : USize) : USize :=
i.land (mul2Shift 1 shift - 1)
/- Tree lookup used by `get`.
   At a `node`, `div2Shift i shift` selects the child, and the search continues in
   that child with index `mod2Shift i shift` and shift `shift - initShift`.
   At a `leaf`, the index is used directly and the shift is ignored. -/
partial def getAux [Inhabited α] : PersistentArrayNode α → USize → USize → α
| (node children) idx sh :=
  let slot := div2Shift idx sh in
  let rest := mod2Shift idx sh in
  getAux (children.get slot.toNat) rest (sh - initShift)
| (leaf vals) idx _ :=
  vals.get idx.toNat
/- Reads the element at index `i`.
   Indices below `tailOff` are looked up in the tree (`getAux` from `root` with the
   array's `shift`); all other indices are read from `tail` at offset `i - tailOff`. -/
def get [Inhabited α] (t : PersistentArray α) (i : Nat) : α :=
if i < t.tailOff then
  getAux t.root (USize.ofNat i) t.shift
else
  t.tail.get (i - t.tailOff)
/- Tree update used by `set`: returns the subtree with the element at the given
   index replaced by `a`.
   At a `node`, the child selected by `div2Shift` is rewritten via `Array.modify`,
   recursing with the `mod2Shift` remainder and `shift - initShift`.
   At a `leaf`, the value is stored with `Array.set`. -/
partial def setAux : PersistentArrayNode α → USize → USize → α → PersistentArrayNode α
| (node children) idx sh val :=
  let slot   := div2Shift idx sh in
  let rest   := mod2Shift idx sh in
  let nextSh := sh - initShift in
  node (children.modify slot.toNat (λ child, setAux child rest nextSh val))
| (leaf vals) idx _ val :=
  leaf (vals.set idx.toNat val)
/- Returns `t` with the element at index `i` replaced by `a`.
   Indices below `tailOff` are updated in the tree via `setAux`; all other indices
   are updated in `tail` at offset `i - tailOff`. Remaining fields are copied from `t`. -/
def set (t : PersistentArray α) (i : Nat) (a : α) : PersistentArray α :=
if i < t.tailOff then
  { root := setAux t.root (USize.ofNat i) t.shift a, .. t }
else
  { tail := t.tail.set (i - t.tailOff) a, .. t }
/- Tree update used by `modify`: returns the subtree with `f` applied to the
   element at the given index.
   At a `node`, the child selected by `div2Shift` is rewritten via `Array.modify`,
   recursing with the `mod2Shift` remainder and `shift - initShift`.
   At a `leaf`, `f` is applied in place with `Array.modify`.
   `f` has type `α → α`; the arrow had been lost from the signature (`αα`). -/
@[specialize] partial def modifyAux [Inhabited α] (f : α → α) : PersistentArrayNode α → USize → USize → PersistentArrayNode α
| (node cs) i shift :=
  let j     := div2Shift i shift in
  let i     := mod2Shift i shift in
  let shift := shift - initShift in
  node $ cs.modify j.toNat $ λ c, modifyAux c i shift
| (leaf cs) i _ := leaf (cs.modify i.toNat f)
/- Returns `t` with `f` applied to the element at index `i`.
   Indices `>= tailOff` are modified in `tail` at offset `i - tailOff`; smaller
   indices are modified in the tree via `modifyAux`. Remaining fields are copied from `t`.
   `f` has type `α → α`; the arrow had been lost from the signature (`αα`). -/
@[specialize] def modify [Inhabited α] (t : PersistentArray α) (i : Nat) (f : α → α) : PersistentArray α :=
if i >= t.tailOff then
  { tail := t.tail.modify (i - t.tailOff) f, .. t }
else
  { root := modifyAux f t.root (USize.ofNat i) t.shift, .. t }
/- Builds a fresh path down to a leaf holding `a`.
   With shift `0` the result is `leaf a`; otherwise it is a `node` whose only child
   is the path built for `shift - initShift`. -/
partial def mkNewPath : USize → Array α → PersistentArrayNode α
| sh vals :=
  if sh == 0 then
    leaf vals
  else
    node $ mkEmptyArray.push $ mkNewPath (sh - initShift) vals
/- Inserts the array `a` as a new leaf into the tree (used by `mkNewTail`).
   At a `node`:
   - if the index is below `branching`, `leaf a` is appended to this node's children;
   - otherwise the child slot is computed with `div2Shift`; if that slot already
     exists the insertion recurses into it (with the `mod2Shift` remainder and
     `shift - initShift`), else a new path built by `mkNewPath` is appended. -/
partial def insertNewLeaf : PersistentArrayNode α → USize → USize → Array α → PersistentArrayNode α
| (node children) idx sh vals :=
  if idx < branching then
    node (children.push (leaf vals))
  else
    let slot   := div2Shift idx sh in
    let rest   := mod2Shift idx sh in
    let nextSh := sh - initShift in
    if slot.toNat < children.size then
      node (children.modify slot.toNat (λ child, insertNewLeaf child rest nextSh vals))
    else
      node (children.push (mkNewPath nextSh vals))
| other _ _ _ := other -- unreachable
/- Moves the current `tail` into the tree and starts a new, empty tail.
   In both cases `tailOff` becomes `t.size`.
   - If `t.size` is at most `1` shifted left by `t.shift + initShift`, the old tail
     is inserted under the existing root with `insertNewLeaf`.
   - Otherwise a new root is created with two children (the old root and a new
     path to the old tail) and `shift` grows by `initShift`. -/
def mkNewTail (t : PersistentArray α) : PersistentArray α :=
let limit := (mul2Shift 1 (t.shift + initShift)).toNat in
if t.size <= limit then
  { root    := insertNewLeaf t.root (USize.ofNat (t.size - 1)) t.shift t.tail,
    tail    := mkEmptyArray,
    tailOff := t.size,
    .. t }
else
  let grown := (mkEmptyArray.push t.root).push (mkNewPath t.shift t.tail) in
  { root    := node grown,
    tail    := Array.empty,
    shift   := t.shift + initShift,
    tailOff := t.size,
    .. t }
/- Size threshold (`usizeSz / 8`): once `size` reaches it, `push` no longer moves
   the tail into the tree and only keeps appending to `tail`. -/
def tooBig : Nat := usizeSz / 8
/- Appends `a` at the end of the array.
   The element is first pushed onto `tail` and `size` is incremented. If the tail
   has then reached `branching` elements and the size before the push was below
   `tooBig`, the tail is moved into the tree with `mkNewTail`. -/
def push (t : PersistentArray α) (a : α) : PersistentArray α :=
let grown := { tail := t.tail.push a, size := t.size + 1, .. t } in
if grown.tail.size >= branching.toNat && t.size < tooBig then
  mkNewTail grown
else
  grown
section
variables {m : Type v → Type v} [Monad m]
/- Monadic left fold over a subtree: children of a `node` are folded in array
   order, and the values of a `leaf` are folded with `f`. -/
@[specialize] partial def mfoldlAux (f : β → α → m β) : PersistentArrayNode α → β → m β
| (node children) acc :=
  children.mfoldl (λ running child, mfoldlAux child running) acc
| (leaf vals) acc :=
  vals.mfoldl f acc
/- Monadic left fold over the whole array: first the tree (`root`), then `tail`. -/
@[specialize] def mfoldl (f : β → α → m β) (b : β) (t : PersistentArray α) : m β :=
do
  fromTree ← mfoldlAux f t.root b,
  t.tail.mfoldl f fromTree
end
/- Pure left fold, obtained by running `mfoldl` in the `Id` monad. -/
@[inline] def foldl (f : β → α → β) (b : β) (t : PersistentArray α) : β :=
Id.run $ t.mfoldl f b
/- Converts the array to a list: elements are consed onto an accumulator while
   folding left, and the accumulated list is reversed at the end. -/
def toList (t : PersistentArray α) : List α :=
List.reverse (t.foldl (λ acc x, x :: acc) [])
section
variables {m : Type v → Type v} [Monad m]
/- Monadic map over a subtree: maps each child of a `node` recursively and each
   value of a `leaf` with `f`, rebuilding a node of the same constructor. -/
@[specialize] partial def mmapAux (f : α → m β) : PersistentArrayNode α → m (PersistentArrayNode β)
| (node children) :=
  node <$> children.mmap (λ child, mmapAux child)
| (leaf vals) :=
  leaf <$> vals.mmap f
/- Monadic map over the whole array: the tree is mapped first, then the tail.
   The remaining fields are copied from `t`. -/
@[specialize] def mmap (f : α → m β) (t : PersistentArray α) : m (PersistentArray β) :=
do
  newRoot ← mmapAux f t.root,
  newTail ← t.tail.mmap f,
  pure { root := newRoot, tail := newTail, .. t }
end
/- Pure map, obtained by running `mmap` in the `Id` monad. -/
@[inline] def map (f : α → β) (t : PersistentArray α) : PersistentArray β :=
Id.run $ t.mmap f
/- Summary of a `PersistentArray`'s shape, as computed by `stats`:
   `numNodes` counts the `node`s and `leaf`s of the tree, `depth` is the largest
   level reached (the root is level 0), and `tailSize` is the size of `tail`. -/
structure Stats :=
(numNodes : Nat) (depth : Nat) (tailSize : Nat)
/- Accumulates `Stats` over a subtree located at level `d`.
   Every visited `node` or `leaf` adds one to `numNodes` and raises `depth` to at
   least its own level; children of a `node` are visited at level `d + 1`. -/
partial def collectStats : PersistentArrayNode α → Stats → Nat → Stats
| (node children) acc lvl :=
  let counted : Stats :=
    { numNodes := acc.numNodes + 1, depth := Nat.max lvl acc.depth, .. acc } in
  children.foldl (λ running child, collectStats child running (lvl+1)) counted
| (leaf _) acc lvl :=
  { numNodes := acc.numNodes + 1, depth := Nat.max lvl acc.depth, .. acc }
/- Computes the `Stats` of `r`: walks the tree from `root` at level 0, starting
   from zero counters and the current size of `tail`. -/
def stats (r : PersistentArray α) : Stats :=
let initial : Stats := { numNodes := 0, depth := 0, tailSize := r.tail.size } in
collectStats r.root initial 0
/- Renders the stats as the string form of the list `[numNodes, depth, tailSize]`. -/
def Stats.toString (s : Stats) : String :=
let fields : List Nat := [s.numNodes, s.depth, s.tailSize] in
toString fields
/- `Stats` values are printed with `Stats.toString`. -/
instance : HasToString Stats :=
HasToString.mk Stats.toString
end PersistentArray
/- Pushes every element of the list, front to back, onto the accumulator array. -/
def List.toPersistentArrayAux {α : Type u} : List α → PersistentArray α → PersistentArray α
| []         acc := acc
| (hd :: tl) acc := List.toPersistentArrayAux tl (acc.push hd)
/- Builds a `PersistentArray` from a list by pushing its elements, in order,
   onto the default (empty) array. -/
def List.toPersistentArray {α : Type u} (xs : List α) : PersistentArray α :=
List.toPersistentArrayAux xs {}