/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
namespace Std
universe u v w w'

namespace PersistentHashMap

/-- One slot of an entries node: a key/value pair stored inline (`entry`), a child node (`ref`),
or an unused slot (`null`). The parameter `σ` is the type of child nodes. -/
inductive Entry (α : Type u) (β : Type v) (σ : Type w) where
  | entry (key : α) (val : β) : Entry α β σ
  | ref   (node : σ) : Entry α β σ
  | null  : Entry α β σ

/-- The default `Entry` is the unused slot `Entry.null`. -/
instance {α β σ} : Inhabited (Entry α β σ) := ⟨Entry.null⟩

/-- A trie node. `entries` holds an array of slots indexed by a chunk of the hash code;
`collision` holds parallel arrays of keys and values (of equal size, witnessed by `h`)
that are searched linearly. -/
inductive Node (α : Type u) (β : Type v) : Type (max u v) where
  | entries   (es : Array (Entry α β (Node α β))) : Node α β
  | collision (ks : Array α) (vs : Array β) (h : ks.size = vs.size) : Node α β

/-- The default `Node` is an entries node with an empty slot array. -/
instance {α β} : Inhabited (Node α β) := ⟨Node.entries #[]⟩

/-- Number of hash bits consumed per trie level. -/
abbrev shift         : USize  := 5
/-- Number of slots in an entries node: `2 ^ shift`. -/
abbrev branching     : USize  := USize.ofNat (2 ^ shift.toNat)
/-- Depth at or beyond which `insertAux` no longer splits a collision node. -/
abbrev maxDepth      : USize  := 7
/-- `insertAux` keeps a collision node as-is while its size stays below this bound. -/
abbrev maxCollisions : Nat    := 4

/-- A slot array of `branching` entries, all `Entry.null`. -/
def mkEmptyEntriesArray {α β} : Array (Entry α β (Node α β)) :=
  (Array.mkArray PersistentHashMap.branching.toNat PersistentHashMap.Entry.null)

end PersistentHashMap

/-- A persistent hash map: a root trie node together with a count of stored keys.
By default the root is an entries node whose slots are all `null`, and `size` is `0`. -/
structure PersistentHashMap (α : Type u) (β : Type v) [BEq α] [Hashable α] where
  root    : PersistentHashMap.Node α β := PersistentHashMap.Node.entries PersistentHashMap.mkEmptyEntriesArray
  size    : Nat                        := 0

/-- Short alias for `PersistentHashMap`. -/
abbrev PHashMap (α : Type u) (β : Type v) [BEq α] [Hashable α] := PersistentHashMap α β

namespace PersistentHashMap

/-- The map built from the structure's default field values. -/
def empty [BEq α] [Hashable α] : PersistentHashMap α β := {}

/-- `true` iff the `size` field is `0`. -/
def isEmpty [BEq α] [Hashable α] (m : PersistentHashMap α β) : Bool :=
  m.size == 0

/-- The default `PersistentHashMap` is the empty map. -/
instance [BEq α] [Hashable α] : Inhabited (PersistentHashMap α β) := ⟨{}⟩

/-- An entries node whose slots are all `Entry.null`. -/
def mkEmptyEntries {α β} : Node α β :=
  Node.entries mkEmptyEntriesArray

/-- `i` shifted left by `shift` bits. -/
abbrev mul2Shift (i : USize) (shift : USize) : USize := i.shiftLeft shift
/-- `i` shifted right by `shift` bits. -/
abbrev div2Shift (i : USize) (shift : USize) : USize := i.shiftRight shift
/-- `i` masked with `(1 <<< shift) - 1`, i.e. the low `shift` bits of `i`. -/
abbrev mod2Shift (i : USize) (shift : USize) : USize := USize.land i ((USize.shiftLeft 1 shift) - 1)

/-- Predicate stating that a node was built with `Node.collision`. -/
inductive IsCollisionNode : Node α β → Prop where
  | mk (keys : Array α) (vals : Array β) (h : keys.size = vals.size) : IsCollisionNode (Node.collision keys vals h)

/-- A node bundled with a proof that it is a collision node. -/
abbrev CollisionNode (α β) := { n : Node α β // IsCollisionNode n }

/-- Predicate stating that a node was built with `Node.entries`. -/
inductive IsEntriesNode : Node α β → Prop where
  | mk (entries : Array (Entry α β (Node α β))) : IsEntriesNode (Node.entries entries)

/-- A node bundled with a proof that it is an entries node. -/
abbrev EntriesNode (α β) := { n : Node α β // IsEntriesNode n }

/-- Overwriting one element in each of two equally sized arrays keeps their sizes equal. -/
private theorem size_set {ks : Array α} {vs : Array β} (h : ks.size = vs.size) (i : Fin ks.size) (j : Fin vs.size) (k : α) (v : β)
    : (ks.set i k).size = (vs.set j v).size := by
  simp [h]

/-- Pushing one element onto each of two equally sized arrays keeps their sizes equal. -/
private theorem size_push {ks : Array α} {vs : Array β} (h : ks.size = vs.size) (k : α) (v : β) : (ks.push k).size = (vs.push v).size := by
  simp [h]

/-- Scans the keys of a collision node starting at index `i`. If some stored key `k'`
satisfies `k == k'`, the key and value at that index are overwritten with `k` and `v`;
if the scan runs past the end, `k` and `v` are pushed onto the arrays. -/
partial def insertAtCollisionNodeAux [BEq α] : CollisionNode α β → Nat → α → β → CollisionNode α β
  | n@⟨Node.collision keys vals heq, _⟩, i, k, v =>
    if h : i < keys.size then
      let idx : Fin keys.size := ⟨i, h⟩;
      let k' := keys.get idx;
      if k == k' then
         let j : Fin vals.size := ⟨i, by rw [←heq]; assumption⟩
         ⟨Node.collision (keys.set idx k) (vals.set j v) (size_set heq idx j k v), IsCollisionNode.mk _ _ _⟩
      else insertAtCollisionNodeAux n (i+1) k v
    else
      ⟨Node.collision (keys.push k) (vals.push v) (size_push heq k v), IsCollisionNode.mk _ _ _⟩
  | ⟨Node.entries _, h⟩, _, _, _ => False.elim (nomatch h)

/-- Inserts or overwrites `k ↦ v` in a collision node, scanning from index `0`. -/
def insertAtCollisionNode [BEq α] : CollisionNode α β → α → β → CollisionNode α β :=
  fun n k v => insertAtCollisionNodeAux n 0 k v

/-- Number of keys stored in a collision node. -/
def getCollisionNodeSize : CollisionNode α β → Nat
  | ⟨Node.collision keys _ _, _⟩ => keys.size
  | ⟨Node.entries _, h⟩          => False.elim (nomatch h)

/-- Builds a collision node holding exactly the two pairs `k₁ ↦ v₁` and `k₂ ↦ v₂`.
Both arrays are created with `Array.mkEmpty maxCollisions`. -/
def mkCollisionNode (k₁ : α) (v₁ : β) (k₂ : α) (v₂ : β) : Node α β :=
  let ks : Array α := Array.mkEmpty maxCollisions
  let ks := (ks.push k₁).push k₂
  let vs : Array β := Array.mkEmpty maxCollisions
  let vs := (vs.push v₁).push v₂
  Node.collision ks vs rfl

/-- Inserts `k ↦ v` into the node. The second argument is the (already shifted) hash of `k`,
the third is the current depth (`insert` starts at `1`).

* Collision node: the pair is inserted with `insertAtCollisionNode`. If `depth >= maxDepth`
  or the resulting size is below `maxCollisions` the collision node is returned; otherwise
  all of its pairs are re-inserted into a fresh entries node, using each key's hash shifted
  right by `shift * (depth - 1)`.
* Entries node: the slot selected by the low `shift` bits of the hash is updated. A `null`
  slot receives the pair, a `ref` slot recurses with the hash shifted by `shift` and
  `depth+1`, and an `entry` slot is either overwritten (when `k == k'`) or replaced by a
  two-element collision node. -/
partial def insertAux [BEq α] [Hashable α] : Node α β → USize → USize → α → β → Node α β
  | Node.collision keys vals heq, _, depth, k, v =>
    let newNode := insertAtCollisionNode ⟨Node.collision keys vals heq, IsCollisionNode.mk _ _ _⟩ k v
    if depth >= maxDepth || getCollisionNodeSize newNode < maxCollisions then newNode.val
    else match newNode with
      | ⟨Node.entries _, h⟩ => False.elim (nomatch h)
      | ⟨Node.collision keys vals heq, _⟩ =>
        let rec traverse (i : Nat) (entries : Node α β) : Node α β :=
          if h : i < keys.size then
            let k := keys[i]
            let v := vals[i]'(heq ▸ h)
            let h := hash k |>.toUSize
            let h := div2Shift h (shift * (depth - 1))
            traverse (i+1) (insertAux entries h depth k v)
          else
            entries
        traverse 0 mkEmptyEntries
  | Node.entries entries, h, depth, k, v =>
    let j     := (mod2Shift h shift).toNat
    Node.entries $ entries.modify j fun entry =>
      match entry with
      | Entry.null        => Entry.entry k v
      | Entry.ref node    => Entry.ref $ insertAux node (div2Shift h shift) (depth+1) k v
      | Entry.entry k' v' =>
        if k == k' then Entry.entry k v
        else Entry.ref $ mkCollisionNode k' v' k v

/-- Linear search in the parallel arrays of a collision node, starting at index `i`.
Returns the value paired with the first stored key `k'` such that `k == k'`. -/
partial def findAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Option β :=
  if h : i < keys.size then
    let k' := keys[i]
    if k == k' then some (vals[i]'(by rw [←heq]; assumption))
    else findAtAux keys vals heq (i+1) k
  else none

/-- Looks up `k` in the node; the `USize` argument is the (already shifted) hash of `k`.
Entries nodes are indexed by the low `shift` bits of the hash; collision nodes are
searched linearly and ignore the hash. -/
partial def findAux [BEq α] : Node α β → USize → α → Option β
  | Node.entries entries, h, k =>
    let j     := (mod2Shift h shift).toNat
    match entries.get! j with
    | Entry.null       => none
    | Entry.ref node   => findAux node (div2Shift h shift) k
    | Entry.entry k' v => if k == k' then some v else none
  | Node.collision keys vals heq, _, k => findAtAux keys vals heq 0 k

/-- Returns the value associated with `k`, or `none` when no stored key matches. -/
def find? {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → Option β
  | { root := n, .. }, k => findAux n (hash k |>.toUSize) k

/-- Same as `find?`, with the map as the first explicit argument. -/
@[inline] def getOp {_ : BEq α} {_ : Hashable α} (self : PersistentHashMap α β) (idx : α) : Option β :=
  self.find? idx

/-- `m[k]` is `m.find? k`; the validity predicate is trivially `True`. -/
instance {_ : BEq α} {_ : Hashable α} : GetElem (PersistentHashMap α β) α (Option β) fun _ _ => True where
  getElem m i _ := m.find? i

/-- Returns the value associated with `a`, or `b₀` when `a` is not found. -/
@[inline] def findD {_ : BEq α} {_ : Hashable α} (m : PersistentHashMap α β) (a : α) (b₀ : β) : β :=
  (m.find? a).getD b₀

/-- Returns the value associated with `a`; panics with "key is not in the map" when absent. -/
@[inline] def find! {_ : BEq α} {_ : Hashable α} [Inhabited β] (m : PersistentHashMap α β) (a : α) : β :=
  match m.find? a with
  | some b => b
  | none   => panic! "key is not in the map"

/-- Like `findAtAux`, but returns the stored key `k'` together with its value. -/
partial def findEntryAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Option (α × β) :=
  if h : i < keys.size then
    let k' := keys[i]
    if k == k' then some (k', vals[i]'(by rw [←heq]; assumption))
    else findEntryAtAux keys vals heq (i+1) k
  else none

/-- Like `findAux`, but returns the stored key `k'` together with its value. -/
partial def findEntryAux [BEq α] : Node α β → USize → α → Option (α × β)
  | Node.entries entries, h, k =>
    let j     := (mod2Shift h shift).toNat
    match entries.get! j with
    | Entry.null       => none
    | Entry.ref node   => findEntryAux node (div2Shift h shift) k
    | Entry.entry k' v => if k == k' then some (k', v) else none
  | Node.collision keys vals heq, _, k => findEntryAtAux keys vals heq 0 k

/-- Returns the stored key/value pair whose key matches `k`, if any. -/
def findEntry? {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → Option (α × β)
  | { root := n, .. }, k => findEntryAux n (hash k |>.toUSize) k

/-- Linear membership test in the key array of a collision node, starting at index `i`. -/
partial def containsAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Bool :=
  if h : i < keys.size then
    let k' := keys[i]
    if k == k' then true
    else containsAtAux keys vals heq (i+1) k
  else false

/-- Membership test mirroring `findAux`: same slot selection and same `k == k'` comparisons. -/
partial def containsAux [BEq α] : Node α β → USize → α → Bool
  | Node.entries entries, h, k =>
    let j     := (mod2Shift h shift).toNat
    match entries.get! j with
    | Entry.null       => false
    | Entry.ref node   => containsAux node (div2Shift h shift) k
    | Entry.entry k' _ => k == k'
  | Node.collision keys vals heq, _, k => containsAtAux keys vals heq 0 k

/-- `true` iff some stored key matches `k`. -/
def contains [BEq α] [Hashable α] : PersistentHashMap α β → α → Bool
  | { root := n, .. }, k => containsAux n (hash k |>.toUSize) k

/-- Inserts `k ↦ v`, overwriting the value when a matching key is already stored.

`size` is incremented only when `k` was not already present: `insertAux` replaces an
existing binding in place, so an unconditional `sz + 1` would over-count on every
re-insertion of the same key. This definition is placed after `containsAux`, which it
uses for the presence check. -/
def insert {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → β → PersistentHashMap α β
  | { root := n, size := sz }, k, v =>
    let h := hash k |>.toUSize
    -- The membership probe is evaluated before `insertAux` is called.
    let isNew := !containsAux n h k
    { root := insertAux n h 1 k v, size := if isNew then sz + 1 else sz }

/-- Scans the slot array from index `i`. Returns `some (k, v)` when exactly one `entry`
slot and no `ref` slot is found; `acc` carries the single pair seen so far. Any `ref`
slot or a second `entry` slot yields `none`. -/
partial def isUnaryEntries (a : Array (Entry α β (Node α β))) (i : Nat) (acc : Option (α × β)) : Option (α × β) :=
  if h : i < a.size then
    match a[i] with
    | Entry.null      => isUnaryEntries a (i+1) acc
    | Entry.ref _     => none
    | Entry.entry k v =>
      match acc with
      | none   => isUnaryEntries a (i+1) (some (k, v))
      | some _ => none
  else acc

/-- If the node stores exactly one key/value pair directly (and, for entries nodes,
has no child node), returns that pair. -/
def isUnaryNode : Node α β → Option (α × β)
  | Node.entries entries         => isUnaryEntries entries 0 none
  | Node.collision keys vals heq =>
    if h : 1 = keys.size then
      have : 0 < keys.size := by rw [←h]; decide
      have : 0 < vals.size := by rw [←heq]; assumption
      some (keys[0], vals[0])
    else
      none

/-- Removes `k` from the node; the `USize` argument is the (already shifted) hash of `k`.
Returns the resulting node and a flag telling whether a binding was removed.
After a successful removal below a `ref` slot, a child for which `isUnaryNode` returns a
pair is replaced by an inline `entry` slot holding that pair. -/
partial def eraseAux [BEq α] : Node α β → USize → α → Node α β × Bool
  | n@(Node.collision keys vals heq), _, k =>
    match keys.indexOf? k with
    | some idx =>
      let ⟨keys', keq⟩ := keys.eraseIdx' idx
      let ⟨vals', veq⟩ := vals.eraseIdx' (Eq.ndrec idx heq)
      have : keys.size - 1 = vals.size - 1 := by rw [heq]
      (Node.collision keys' vals' (keq.trans (this.trans veq.symm)), true)
    | none     => (n, false)
  | n@(Node.entries entries), h, k =>
    let j       := (mod2Shift h shift).toNat
    let entry   := entries.get! j
    match entry with
    | Entry.null       => (n, false)
    | Entry.entry k' _ =>
      if k == k' then (Node.entries (entries.set! j Entry.null), true) else (n, false)
    | Entry.ref node   =>
      -- NOTE(review): the slot is cleared before recursing, presumably so the array no
      -- longer references `node` while it is being modified; confirm the intent.
      let entries := entries.set! j Entry.null
      let (newNode, deleted) := eraseAux node (div2Shift h shift) k
      if !deleted then (n, false)
      else match isUnaryNode newNode with
        | none        => (Node.entries (entries.set! j (Entry.ref newNode)), true)
        | some (k, v) => (Node.entries (entries.set! j (Entry.entry k v)), true)

/-- Removes the binding for `k`, if any. `size` is decremented only when `eraseAux`
reports that a binding was removed. -/
def erase {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → PersistentHashMap α β
  | { root := n, size := sz }, k =>
    let h := hash k |>.toUSize
    let (n, del) := eraseAux n h k
    { root := n, size := if del then sz - 1 else sz }

section
-- `m` is the monad the folds run in; `σ` is the accumulator type.
variable {m : Type w → Type w'} [Monad m]
variable {σ : Type w}

/-- Monadic left fold over every key/value pair reachable from the node. Collision nodes
are traversed by increasing index; entries nodes are traversed with `Array.foldlM`,
recursing into `ref` slots and skipping `null` slots. -/
@[specialize] partial def foldlMAux (f : σ → α → β → m σ) : Node α β → σ → m σ
  | Node.collision keys vals heq, acc =>
    let rec traverse (i : Nat) (acc : σ) : m σ := do
      if h : i < keys.size then
        let k := keys[i]
        let v := vals[i]'(heq ▸ h)
        traverse (i+1) (← f acc k v)
      else
        pure acc
    traverse 0 acc
  | Node.entries entries, acc => entries.foldlM (fun acc entry =>
    match entry with
    | Entry.null      => pure acc
    | Entry.entry k v => f acc k v
    | Entry.ref node  => foldlMAux f node acc)
    acc

/-- Monadic left fold over all key/value pairs of the map, starting from `init`. -/
@[specialize] def foldlM {_ : BEq α} {_ : Hashable α} (map : PersistentHashMap α β) (f : σ → α → β → m σ) (init : σ) : m σ :=
  foldlMAux f map.root init

/-- Runs `f` on every key/value pair of the map. -/
@[specialize] def forM {_ : BEq α} {_ : Hashable α} (map : PersistentHashMap α β) (f : α → β → m PUnit) : m PUnit :=
  map.foldlM (fun _ => f) ⟨⟩

/-- Pure left fold over all key/value pairs of the map, implemented via `foldlM` in `Id`. -/
@[specialize] def foldl {_ : BEq α} {_ : Hashable α} (map : PersistentHashMap α β) (f : σ → α → β → σ) (init : σ) : σ :=
  Id.run $ map.foldlM f init
end

/-- Collects all key/value pairs into a list. Each visited pair is consed onto the front
of the accumulator, so the list is in reverse traversal order. -/
def toList {_ : BEq α} {_ : Hashable α} (m : PersistentHashMap α β) : List (α × β) :=
  m.foldl (init := []) fun ps k v => (k, v) :: ps

/-- Counters filled in by `collectStats`. -/
structure Stats where
  numNodes      : Nat := 0
  numNull       : Nat := 0
  numCollisions : Nat := 0
  maxDepth      : Nat := 0

/-- Accumulates statistics for the subtree rooted at the node, which sits at depth `depth`.
Every node bumps `numNodes` and updates `maxDepth`; a collision node adds `keys.size - 1`
to `numCollisions`; every `null` slot of an entries node bumps `numNull`. -/
partial def collectStats : Node α β → Stats → Nat → Stats
  | Node.collision keys _ _, stats, depth =>
    { stats with
      numNodes      := stats.numNodes + 1,
      numCollisions := stats.numCollisions + keys.size - 1,
      maxDepth      := Nat.max stats.maxDepth depth }
  | Node.entries entries, stats, depth =>
    let stats :=
      { stats with
        numNodes      := stats.numNodes + 1,
        maxDepth      := Nat.max stats.maxDepth depth }
    entries.foldl (fun stats entry =>
      match entry with
      | Entry.null      => { stats with numNull := stats.numNull + 1 }
      | Entry.ref node  => collectStats node stats (depth + 1)
      | Entry.entry _ _ => stats)
      stats

/-- Statistics for the whole map; the root is counted as depth `1`. -/
def stats {_ : BEq α} {_ : Hashable α} (m : PersistentHashMap α β) : Stats :=
  collectStats m.root {} 1

/-- Renders the four counters in a record-like textual form. -/
def Stats.toString (s : Stats) : String :=
  s!"\{ nodes := {s.numNodes}, null := {s.numNull}, collisions := {s.numCollisions}, depth := {s.maxDepth}}"

instance : ToString Stats := ⟨Stats.toString⟩

end PersistentHashMap
end Std