diff --git a/library/init/data/array/basic.lean b/library/init/data/array/basic.lean
index 3da3053a2e..e1eedeebc2 100644
--- a/library/init/data/array/basic.lean
+++ b/library/init/data/array/basic.lean
@@ -92,6 +92,9 @@ def fset (a : Array α) (i : @& Fin a.size) (v : α) : Array α :=
 theorem szFSetEq (a : Array α) (i : Fin a.size) (v : α) : (fset a i v).size = a.size :=
 rfl
 
+theorem szPushEq (a : Array α) (v : α) : (push a v).size = a.size + 1 :=
+rfl
+
 /- Low-level version of `fset` which is as fast as a C array fset.
    `Fin` values are represented as tag pointers in the Lean runtime. Thus,
    `fset` may be slightly slower than `uset`. -/
diff --git a/library/init/data/persistenthashmap/basic.lean b/library/init/data/persistenthashmap/basic.lean
new file mode 100644
index 0000000000..93c0bf604e
--- /dev/null
+++ b/library/init/data/persistenthashmap/basic.lean
@@ -0,0 +1,240 @@
+/-
+Copyright (c) 2019 Microsoft Corporation. All rights reserved.
+Released under Apache 2.0 license as described in the file LICENSE.
+Authors: Leonardo de Moura
+-/
+prelude
+import init.data.array
+import init.data.hashable
+universes u v w w'
+
+namespace PersistentHashMap
+
+/- A slot of an `entries` node: a key/value pair stored inline, a reference
+   to a child node of type `σ`, or an empty slot. -/
+inductive Entry (α : Type u) (β : Type v) (σ : Type w)
+| entry {} (key : α) (val : β) : Entry
+| ref   {} (node : σ) : Entry
+| null  {} : Entry
+
+instance Entry.inhabited {α β σ} : Inhabited (Entry α β σ) := ⟨Entry.null⟩
+
+/- A tree node: either an array of `Entry` slots, or a collision node holding
+   parallel key/value arrays together with a proof that their sizes agree. -/
+inductive Node (α : Type u) (β : Type v) : Type (max u v)
+| entries   (es : Array (Entry α β Node)) : Node
+| collision (ks : Array α) (vs : Array β) (h : ks.size = vs.size) : Node
+
+instance Node.inhabited {α β} : Inhabited (Node α β) := ⟨Node.entries Array.empty⟩
+
+/- `shift` hash bits select a slot at each level, so an `entries` node is
+   allocated with `branching = 2 ^ shift` slots. -/
+abbrev shift         : USize  := 5
+abbrev branching     : USize  := USize.ofNat (2 ^ shift.toNat)
+abbrev maxDepth      : USize  := 7
+abbrev maxCollisions : Nat    := 4
+
+/- An array of `branching` empty (`Entry.null`) slots. -/
+def mkEmptyEntriesArray {α β} : Array (Entry α β (Node α β)) :=
+(Array.mkArray PersistentHashMap.branching.toNat PersistentHashMap.Entry.null)
+
+end PersistentHashMap
+
+/- A persistent hash map: the root node of the tree and the number of keys stored. -/
+structure PersistentHashMap (α : Type u) (β : Type v) :=
+(root    : PersistentHashMap.Node α β := PersistentHashMap.Node.entries PersistentHashMap.mkEmptyEntriesArray)
+(size    : Nat                        := 0)
+
+abbrev PHashMap (α : Type u) (β : Type v) := PersistentHashMap α β
+
+namespace PersistentHashMap
+variables {α : Type u} {β : Type v}
+
+def empty : PersistentHashMap α β := {}
+
+instance : Inhabited (PersistentHashMap α β) := ⟨{}⟩
+
+def mkEmptyEntries {α β} : Node α β :=
+Node.entries mkEmptyEntriesArray
+
+/- Multiplication, division and remainder by `2 ^ shift`, done with bit operations. -/
+abbrev mul2Shift (i : USize) (shift : USize) : USize := USize.shift_left i shift
+abbrev div2Shift (i : USize) (shift : USize) : USize := USize.shift_right i shift
+abbrev mod2Shift (i : USize) (shift : USize) : USize := USize.land i ((USize.shift_left 1 shift) - 1)
+
+inductive IsCollisionNode : Node α β → Prop
+| mk (keys : Array α) (vals : Array β) (h : keys.size = vals.size) : IsCollisionNode (Node.collision keys vals h)
+
+abbrev CollisionNode (α β) := { n : Node α β // IsCollisionNode n }
+
+inductive IsEntriesNode : Node α β → Prop
+| mk (entries : Array (Entry α β (Node α β))) : IsEntriesNode (Node.entries entries)
+
+abbrev EntriesNode (α β) := { n : Node α β // IsEntriesNode n }
+
+private theorem fsetSizeEq {ks : Array α} {vs : Array β} (h : ks.size = vs.size) (i : Fin ks.size) (j : Fin vs.size) (k : α) (v : β)
+                           : (ks.fset i k).size = (vs.fset j v).size :=
+have h₁ : (ks.fset i k).size = ks.size from Array.szFSetEq _ _ _;
+have h₂ : (vs.fset j v).size = vs.size from Array.szFSetEq _ _ _;
+(h₁.trans h).trans h₂.symm
+
+private theorem pushSizeEq {ks : Array α} {vs : Array β} (h : ks.size = vs.size) (k : α) (v : β) : (ks.push k).size = (vs.push v).size :=
+have h₁ : (ks.push k).size = ks.size + 1 from Array.szPushEq _ _;
+have h₂ : (vs.push v).size = vs.size + 1 from Array.szPushEq _ _;
+have h₃ : ks.size + 1 = vs.size + 1      from h ▸ rfl;
+(h₁.trans h₃).trans h₂.symm
+
+/- Linear scan of a collision node starting at index `i`: overwrite the entry
+   whose key is `==` to `k`, or append `(k, v)` if the scan reaches the end. -/
+partial def insertAtCollisionNodeAux [HasBeq α] : CollisionNode α β → Nat → α → β → CollisionNode α β
+| n@⟨Node.collision keys vals heq, _⟩ i k v :=
+  if h : i < keys.size then
+    let idx : Fin keys.size := ⟨i, h⟩;
+    let k' := keys.fget idx;
+    if k == k' then
+      let j : Fin vals.size := ⟨i, heq ▸ h⟩;
+      ⟨Node.collision (keys.fset idx k) (vals.fset j v) (fsetSizeEq heq idx j k v), IsCollisionNode.mk _ _ _⟩
+    else insertAtCollisionNodeAux n (i+1) k v
+  else
+    ⟨Node.collision (keys.push k) (vals.push v) (pushSizeEq heq k v), IsCollisionNode.mk _ _ _⟩
+| ⟨Node.entries _, h⟩ _ _ _ := False.elim (nomatch h)
+
+def insertAtCollisionNode [HasBeq α] : CollisionNode α β → α → β → CollisionNode α β :=
+fun n k v => insertAtCollisionNodeAux n 0 k v
+
+def getCollisionNodeSize : CollisionNode α β → Nat
+| ⟨Node.collision keys _ _, _⟩ := keys.size
+| ⟨Node.entries _, h⟩          := False.elim (nomatch h)
+
+/- A collision node holding exactly the two given key/value pairs. -/
+def mkCollisionNode (k₁ : α) (v₁ : β) (k₂ : α) (v₂ : β) : Node α β :=
+let ks : Array α := Array.mkEmpty maxCollisions;
+let ks := (ks.push k₁).push k₂;
+let vs : Array β := Array.mkEmpty maxCollisions;
+let vs := (vs.push v₁).push v₂;
+Node.collision ks vs rfl
+
+/- Insert `(k, v)` into a node. `h` is the part of the key's hash not consumed by
+   the levels above and `depth` is the node's level, with the root at depth 1.
+   A collision node that holds `maxCollisions` or more keys after the insertion is
+   rebuilt as an `entries` node, unless `depth` has already reached `maxDepth`. -/
+partial def insertAux [HasBeq α] [Hashable α] : Node α β → USize → USize → α → β → Node α β
+| (Node.collision keys vals heq) _ depth k v :=
+  let newNode := insertAtCollisionNode ⟨Node.collision keys vals heq, IsCollisionNode.mk _ _ _⟩ k v;
+  if depth >= maxDepth || getCollisionNodeSize newNode < maxCollisions then newNode.val
+  else match newNode with
+    | ⟨Node.entries _, h⟩ => False.elim (nomatch h)
+    | ⟨Node.collision keys vals heq, _⟩ =>
+      let entries : Node α β := mkEmptyEntries;
+      keys.iterate entries $ fun i k entries =>
+        let v := vals.fget ⟨i.val, heq ▸ i.isLt⟩;
+        let h := hash k;
+        -- dbgTrace ("toCollision " ++ toString i ++ ", h: " ++ toString h ++ ", depth: " ++ toString depth ++ ", h': " ++
+        --          toString (div2Shift h (shift * (depth - 1)))) $ fun _ =>
+        let h := div2Shift h (shift * (depth - 1));
+        insertAux entries h depth k v
+| (Node.entries entries) h depth k v :=
+  let j := (mod2Shift h shift).toNat;
+  Node.entries $ entries.modify j $ fun entry =>
+    match entry with
+    | Entry.null        => Entry.entry k v
+    | Entry.ref node    => Entry.ref $ insertAux node (div2Shift h shift) (depth+1) k v
+    | Entry.entry k' v' =>
+      if k == k' then Entry.entry k v
+      else Entry.ref $ mkCollisionNode k' v' k v
+
+/- Linear search for `k` in the parallel arrays of a collision node, starting at index `i`. -/
+partial def findAtAux [HasBeq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) : Nat → α → Option β
+| i k :=
+  if h : i < keys.size then
+    let k' := keys.fget ⟨i, h⟩;
+    if k == k' then some (vals.fget ⟨i, heq ▸ h⟩)
+    else findAtAux (i+1) k
+  else none
+
+/- Look up `k` starting at the given node; `h` is the not-yet-consumed part of the key's hash. -/
+partial def findAux [HasBeq α] : Node α β → USize → α → Option β
+| (Node.entries entries) h k :=
+  let j := (mod2Shift h shift).toNat;
+  match entries.get j with
+  | Entry.null       => none
+  | Entry.ref node   => findAux node (div2Shift h shift) k
+  | Entry.entry k' v => if k == k' then some v else none
+| (Node.collision keys vals heq) _ k := findAtAux keys vals heq 0 k
+
+/- Insert `(k, v)`, replacing the value if a key `==` to `k` is already present.
+   `size` grows only when `k` is new, so it counts keys rather than `insert` calls.
+   The key is looked up in the old root first to find out which case applies. -/
+def insert [HasBeq α] [Hashable α] : PersistentHashMap α β → α → β → PersistentHashMap α β
+| { root := n, size := sz } k v :=
+  let h := hash k;
+  let newSize := (match findAux n h k with
+    | some _ => sz
+    | none   => sz + 1);
+  { root := insertAux n h 1 k v, size := newSize }
+
+/- The value stored under a key `==` to `k`, if any. -/
+def find [HasBeq α] [Hashable α] : PersistentHashMap α β → α → Option β
+| { root := n, .. } k := findAux n (hash k) k
+
+section
+variables {m : Type w → Type w'} [Monad m]
+variables {σ : Type w}
+
+/- Monadic left fold over every key/value pair reachable from a node. -/
+@[specialize] partial def mfoldlAux (f : σ → α → β → m σ) : Node α β → σ → m σ
+| (Node.collision keys vals heq) acc := keys.miterate acc $ fun i k acc => f acc k (vals.fget ⟨i.val, heq ▸ i.isLt⟩)
+| (Node.entries entries) acc := entries.mfoldl (fun acc entry =>
+    match entry with
+    | Entry.null      => pure acc
+    | Entry.entry k v => f acc k v
+    | Entry.ref node  => mfoldlAux node acc)
+  acc
+
+@[specialize] def mfoldl (map : PersistentHashMap α β) (f : σ → α → β → m σ) (acc : σ) : m σ :=
+mfoldlAux f map.root acc
+
+@[specialize] def foldl (map : PersistentHashMap α β) (f : σ → α → β → σ) (acc : σ) : σ :=
+Id.run $ map.mfoldl f acc
+end
+
+/- All key/value pairs of the map, collected into a list by `foldl`. -/
+def toList (m : PersistentHashMap α β) : List (α × β) :=
+m.foldl (fun ps k v => (k, v) :: ps) []
+
+/- Counters filled in by `collectStats`. -/
+structure Stats :=
+(numNodes      : Nat := 0)
+(numNull       : Nat := 0)
+(numCollisions : Nat := 0)
+(maxDepth      : Nat := 0)
+
+/- Accumulate statistics for the subtree rooted at a node located at the given depth. -/
+partial def collectStats : Node α β → Stats → Nat → Stats
+| (Node.collision keys _ _) stats depth :=
+  { numNodes      := stats.numNodes + 1,
+    numCollisions := stats.numCollisions + keys.size - 1,
+    maxDepth      := Nat.max stats.maxDepth depth,
+    .. stats }
+| (Node.entries entries) stats depth :=
+  let stats :=
+    { numNodes      := stats.numNodes + 1,
+      maxDepth      := Nat.max stats.maxDepth depth,
+      .. stats };
+  entries.foldl (fun stats entry =>
+    match entry with
+    | Entry.null      => { numNull := stats.numNull + 1, .. stats }
+    | Entry.ref node  => collectStats node stats (depth + 1)
+    | Entry.entry _ _ => stats)
+    stats
+
+def stats (m : PersistentHashMap α β) : Stats :=
+collectStats m.root {} 1
+
+def Stats.toString (s : Stats) : String :=
+"{ nodes := " ++ toString s.numNodes ++ ", null := " ++ toString s.numNull ++
+", collisions := " ++ toString s.numCollisions ++ ", depth := " ++ toString s.maxDepth ++ "}"
+
+instance : HasToString Stats := ⟨Stats.toString⟩
+
+end PersistentHashMap
diff --git a/library/init/data/persistenthashmap/default.lean b/library/init/data/persistenthashmap/default.lean
new file mode 100644
index 0000000000..7bce68ee64
--- /dev/null
+++ b/library/init/data/persistenthashmap/default.lean
@@ -0,0 +1,7 @@
+/-
+Copyright (c) 2019 Microsoft Corporation. All rights reserved.
+Released under Apache 2.0 license as described in the file LICENSE.
+Authors: Leonardo de Moura
+-/
+prelude
+import init.data.persistenthashmap.basic
diff --git a/tests/playground/phashmap.lean b/tests/playground/phashmap.lean
new file mode 100644
index 0000000000..1b79a46fd2
--- /dev/null
+++ b/tests/playground/phashmap.lean
@@ -0,0 +1,43 @@
+import init.data.persistenthashmap
+import init.lean.format
+open Lean PersistentHashMap
+
+abbrev Map := PersistentHashMap Nat Nat
+
+partial def formatMap : Node Nat Nat → Format
+| (Node.collision keys vals _) := Format.sbracket $
+  keys.size.fold
+    (fun i fmt =>
+      let k := keys.get i;
+      let v := vals.get i;
+      let p := if i > 0 then fmt ++ format "," ++ Format.line else fmt;
+      p ++ "c@" ++ Format.paren (format k ++ " => " ++ format v))
+    Format.nil
+| (Node.entries entries) := Format.sbracket $
+  entries.size.fold
+    (fun i fmt =>
+      let entry := entries.get i;
+      let p := if i > 0 then fmt ++ format "," ++ Format.line else fmt;
+      p ++
+      match entry with
+      | Entry.null      => ""
+      | Entry.ref node  => formatMap node
+      | Entry.entry k v => Format.paren (format k ++ " => " ++ format v))
+    Format.nil
+
+def mkMap (n : Nat) : Map :=
+n.fold (fun i m => m.insert i (i*10)) PersistentHashMap.empty
+
+def check (n : Nat) (m : Map) : IO Unit :=
+n.mfor $ fun i =>
+  match m.find i with
+  | none   => IO.println ("failed to find " ++ toString i)
+  | some v => unless (v == i*10) (IO.println ("unexpected value " ++ toString i ++ " => " ++ toString v))
+
+def main (xs : List String) : IO Unit :=
+do
+let n := 1000000;
+let m := mkMap n;
+-- IO.println (formatMap m.root);
+IO.println m.stats;
+check n m