feat(library/init/data): add PersistentHashMap

This commit is contained in:
Leonardo de Moura 2019-08-02 09:08:24 -07:00
parent 25d8c76910
commit c371b43970
4 changed files with 262 additions and 0 deletions

View file

@ -92,6 +92,9 @@ def fset (a : Array α) (i : @& Fin a.size) (v : α) : Array α :=
theorem szFSetEq (a : Array α) (i : Fin a.size) (v : α) : (fset a i v).size = a.size :=
rfl
theorem szPushEq (a : Array α) (v : α) : (push a v).size = a.size + 1 :=
rfl
/- Low-level version of `fset` which is as fast as a C array fset.
`Fin` values are represented as tag pointers in the Lean runtime. Thus,
`fset` may be slightly slower than `uset`. -/

View file

@ -0,0 +1,209 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import init.data.array
import init.data.hashable
universes u v w w'
namespace PersistentHashMap
inductive Entry (α : Type u) (β : Type v) (σ : Type w)
| entry {} (key : α) (val : β) : Entry
| ref {} (node : σ) : Entry
| null {} : Entry
instance Entry.inhabited {α β σ} : Inhabited (Entry α β σ) := ⟨Entry.null⟩
inductive Node (α : Type u) (β : Type v) : Type (max u v)
| entries (es : Array (Entry α β Node)) : Node
| collision (ks : Array α) (vs : Array β) (h : ks.size = vs.size) : Node
instance Node.inhabited {α β} : Inhabited (Node α β) := ⟨Node.entries Array.empty⟩
abbrev shift : USize := 5
abbrev branching : USize := USize.ofNat (2 ^ shift.toNat)
abbrev maxDepth : USize := 7
abbrev maxCollisions : Nat := 4
def mkEmptyEntriesArray {α β} : Array (Entry α β (Node α β)) :=
(Array.mkArray PersistentHashMap.branching.toNat PersistentHashMap.Entry.null)
end PersistentHashMap
structure PersistentHashMap (α : Type u) (β : Type v) :=
(root : PersistentHashMap.Node α β := PersistentHashMap.Node.entries PersistentHashMap.mkEmptyEntriesArray)
(size : Nat := 0)
abbrev PHashMap (α : Type u) (β : Type v) := PersistentHashMap α β
namespace PersistentHashMap
variables {α : Type u} {β : Type v}
def empty : PersistentHashMap α β := {}
instance : Inhabited (PersistentHashMap α β) := ⟨{}⟩
def mkEmptyEntries {α β} : Node α β :=
Node.entries mkEmptyEntriesArray
abbrev mul2Shift (i : USize) (shift : USize) : USize := USize.shift_left i shift
abbrev div2Shift (i : USize) (shift : USize) : USize := USize.shift_right i shift
abbrev mod2Shift (i : USize) (shift : USize) : USize := USize.land i ((USize.shift_left 1 shift) - 1)
inductive IsCollisionNode : Node α β → Prop
| mk (keys : Array α) (vals : Array β) (h : keys.size = vals.size) : IsCollisionNode (Node.collision keys vals h)
abbrev CollisionNode (α β) := { n : Node α β // IsCollisionNode n }
inductive IsEntriesNode : Node α β → Prop
| mk (entries : Array (Entry α β (Node α β))) : IsEntriesNode (Node.entries entries)
abbrev EntriesNode (α β) := { n : Node α β // IsEntriesNode n }
private theorem fsetSizeEq {ks : Array α} {vs : Array β} (h : ks.size = vs.size) (i : Fin ks.size) (j : Fin vs.size) (k : α) (v : β)
: (ks.fset i k).size = (vs.fset j v).size :=
have h₁ : (ks.fset i k).size = ks.size from Array.szFSetEq _ _ _;
have h₂ : (vs.fset j v).size = vs.size from Array.szFSetEq _ _ _;
(h₁.trans h).trans h₂.symm
private theorem pushSizeEq {ks : Array α} {vs : Array β} (h : ks.size = vs.size) (k : α) (v : β) : (ks.push k).size = (vs.push v).size :=
have h₁ : (ks.push k).size = ks.size + 1 from Array.szPushEq _ _;
have h₂ : (vs.push v).size = vs.size + 1 from Array.szPushEq _ _;
have h₃ : ks.size + 1 = vs.size + 1 from h ▸ rfl;
(h₁.trans h₃).trans h₂.symm
partial def insertAtCollisionNodeAux [HasBeq α] : CollisionNode α β → Nat → α → β → CollisionNode α β
| n@⟨Node.collision keys vals heq, _⟩ i k v :=
if h : i < keys.size then
let idx : Fin keys.size := ⟨i, h⟩;
let k' := keys.fget idx;
if k == k' then
let j : Fin vals.size := ⟨i, heq ▸ h⟩;
⟨Node.collision (keys.fset idx k) (vals.fset j v) (fsetSizeEq heq idx j k v), IsCollisionNode.mk _ _ _⟩
else insertAtCollisionNodeAux n (i+1) k v
else
⟨Node.collision (keys.push k) (vals.push v) (pushSizeEq heq k v), IsCollisionNode.mk _ _ _⟩
| ⟨Node.entries _, h⟩ _ _ _ := False.elim (nomatch h)
def insertAtCollisionNode [HasBeq α] : CollisionNode α β → α → β → CollisionNode α β :=
fun n k v => insertAtCollisionNodeAux n 0 k v
def getCollisionNodeSize : CollisionNode α β → Nat
| ⟨Node.collision keys _ _, _⟩ := keys.size
| ⟨Node.entries _, h⟩ := False.elim (nomatch h)
def mkCollisionNode (k₁ : α) (v₁ : β) (k₂ : α) (v₂ : β) : Node α β :=
let ks : Array α := Array.mkEmpty maxCollisions;
let ks := (ks.push k₁).push k₂;
let vs : Array β := Array.mkEmpty maxCollisions;
let vs := (vs.push v₁).push v₂;
Node.collision ks vs rfl
partial def insertAux [HasBeq α] [Hashable α] : Node α β → USize → USize → α → β → Node α β
| (Node.collision keys vals heq) _ depth k v :=
let newNode := insertAtCollisionNode ⟨Node.collision keys vals heq, IsCollisionNode.mk _ _ _⟩ k v;
if depth >= maxDepth || getCollisionNodeSize newNode < maxCollisions then newNode.val
else match newNode with
| ⟨Node.entries _, h⟩ => False.elim (nomatch h)
| ⟨Node.collision keys vals heq, _⟩ =>
let entries : Node α β := mkEmptyEntries;
keys.iterate entries $ fun i k entries =>
let v := vals.fget ⟨i.val, heq ▸ i.isLt⟩;
let h := hash k;
-- dbgTrace ("toCollision " ++ toString i ++ ", h: " ++ toString h ++ ", depth: " ++ toString depth ++ ", h': " ++
-- toString (div2Shift h (shift * (depth - 1)))) $ fun _ =>
let h := div2Shift h (shift * (depth - 1));
insertAux entries h depth k v
| (Node.entries entries) h depth k v :=
let j := (mod2Shift h shift).toNat;
Node.entries $ entries.modify j $ fun entry =>
match entry with
| Entry.null => Entry.entry k v
| Entry.ref node => Entry.ref $ insertAux node (div2Shift h shift) (depth+1) k v
| Entry.entry k' v' =>
if k == k' then Entry.entry k v
else Entry.ref $ mkCollisionNode k' v' k v
def insert [HasBeq α] [Hashable α] : PersistentHashMap α β → α → β → PersistentHashMap α β
| { root := n, size := sz } k v := { root := insertAux n (hash k) 1 k v, size := sz + 1 }
partial def findAtAux [HasBeq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) : Nat → α → Option β
| i k :=
if h : i < keys.size then
let k' := keys.fget ⟨i, h⟩;
if k == k' then some (vals.fget ⟨i, heq ▸ h⟩)
else findAtAux (i+1) k
else none
partial def findAux [HasBeq α] : Node α β → USize → α → Option β
| (Node.entries entries) h k :=
let j := (mod2Shift h shift).toNat;
match entries.get j with
| Entry.null => none
| Entry.ref node => findAux node (div2Shift h shift) k
| Entry.entry k' v => if k == k' then some v else none
| (Node.collision keys vals heq) _ k := findAtAux keys vals heq 0 k
def find [HasBeq α] [Hashable α] : PersistentHashMap α β → α → Option β
| { root := n, .. } k := findAux n (hash k) k
section
variables {m : Type w → Type w'} [Monad m]
variables {σ : Type w}
@[specialize] partial def mfoldlAux (f : σα → β → m σ) : Node α β → σ → m σ
| (Node.collision keys vals heq) acc := keys.miterate acc $ fun i k acc => f acc k (vals.fget ⟨i.val, heq ▸ i.isLt⟩)
| (Node.entries entries) acc := entries.mfoldl (fun acc entry =>
match entry with
| Entry.null => pure acc
| Entry.entry k v => f acc k v
| Entry.ref node => mfoldlAux node acc)
acc
@[specialize] def mfoldl (map : PersistentHashMap α β) (f : σα → β → m σ) (acc : σ) : m σ :=
mfoldlAux f map.root acc
@[specialize] def foldl (map : PersistentHashMap α β) (f : σα → β → σ) (acc : σ) : σ :=
Id.run $ map.mfoldl f acc
end
def toList (m : PersistentHashMap α β) : List (α × β) :=
m.foldl (fun ps k v => (k, v) :: ps) []
structure Stats :=
(numNodes : Nat := 0)
(numNull : Nat := 0)
(numCollisions : Nat := 0)
(maxDepth : Nat := 0)
partial def collectStats : Node α β → Stats → Nat → Stats
| (Node.collision keys _ _) stats depth :=
{ numNodes := stats.numNodes + 1,
numCollisions := stats.numCollisions + keys.size - 1,
maxDepth := Nat.max stats.maxDepth depth,
.. stats }
| (Node.entries entries) stats depth :=
let stats :=
{ numNodes := stats.numNodes + 1,
maxDepth := Nat.max stats.maxDepth depth,
.. stats };
entries.foldl (fun stats entry =>
match entry with
| Entry.null => { numNull := stats.numNull + 1, .. stats }
| Entry.ref node => collectStats node stats (depth + 1)
| Entry.entry _ _ => stats)
stats
def stats (m : PersistentHashMap α β) : Stats :=
collectStats m.root {} 1
def Stats.toString (s : Stats) : String :=
"{ nodes := " ++ toString s.numNodes ++ ", null := " ++ toString s.numNull ++
", collisions := " ++ toString s.numCollisions ++ ", depth := " ++ toString s.maxDepth ++ "}"
instance : HasToString Stats := ⟨Stats.toString⟩
end PersistentHashMap

View file

@ -0,0 +1,7 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import init.data.persistenthashmap.basic

View file

@ -0,0 +1,43 @@
import init.data.persistenthashmap
import init.lean.format
open Lean PersistentHashMap
abbrev Map := PersistentHashMap Nat Nat
partial def formatMap : Node Nat Nat → Format
| (Node.collision keys vals _) := Format.sbracket $
keys.size.fold
(fun i fmt =>
let k := keys.get i;
let v := vals.get i;
let p := if i > 0 then fmt ++ format "," ++ Format.line else fmt;
p ++ "c@" ++ Format.paren (format k ++ " => " ++ format v))
Format.nil
| (Node.entries entries) := Format.sbracket $
entries.size.fold
(fun i fmt =>
let entry := entries.get i;
let p := if i > 0 then fmt ++ format "," ++ Format.line else fmt;
p ++
match entry with
| Entry.null => "<null>"
| Entry.ref node => formatMap node
| Entry.entry k v => Format.paren (format k ++ " => " ++ format v))
Format.nil
def mkMap (n : Nat) : Map :=
n.fold (fun i m => m.insert i (i*10)) PersistentHashMap.empty
def check (n : Nat) (m : Map) : IO Unit :=
n.mfor $ fun i =>
match m.find i with
| none => IO.println ("failed to find " ++ toString i)
| some v => unless (v == i*10) (IO.println ("unexpected value " ++ toString i ++ " => " ++ toString v))
def main (xs : List String) : IO Unit :=
do
let n := 1000000;
let m := mkMap n;
-- IO.println (formatMap m.root);
IO.println m.stats;
check n m