lean4-htt/src/Std/Data/PersistentHashMap.lean

319 lines
13 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
namespace Std
universe u v w w'
namespace PersistentHashMap
/--
One slot of an entries node. A slot either stores a key/value pair in place,
points at a child of type `σ`, or is empty.
-/
inductive Entry (α : Type u) (β : Type v) (σ : Type w) where
  -- A key/value pair stored directly in the slot.
  | entry (key : α) (val : β)
  -- A link to a child node.
  | ref (node : σ)
  -- An unoccupied slot.
  | null

/-- The empty slot serves as the default `Entry`. -/
instance {α β σ} : Inhabited (Entry α β σ) where
  default := .null
/--
A node of the map's tree.

* `entries` holds an array of slots (`Entry`), each of which may itself refer to a child `Node`.
* `collision` holds a key array and a value array together with a proof that
  both arrays have the same size.
-/
inductive Node (α : Type u) (β : Type v) : Type (max u v) where
  | entries (es : Array (Entry α β (Node α β))) : Node α β
  | collision (ks : Array α) (vs : Array β) (h : ks.size = vs.size) : Node α β

/-- An entries node with an empty slot array is the default `Node`. -/
instance {α β} : Inhabited (Node α β) where
  default := .entries #[]
/-- Number of hash bits used to pick a slot at one level of the tree. -/
abbrev shift : USize := 5

/-- Number of slots in an entries node: `2 ^ shift`. -/
abbrev branching : USize := USize.ofNat <| 2 ^ shift.toNat

/-- Depth at or beyond which `insertAux` leaves a collision node as it is instead of splitting it. -/
abbrev maxDepth : USize := 7

/--
Size bound for collision nodes: `insertAux` only splits a collision node once its size
is no longer below this value; `mkCollisionNode` also uses it as the initial array capacity.
-/
abbrev maxCollisions : Nat := 4
/-- A slot array with `branching` positions, every one of them `Entry.null`. -/
def mkEmptyEntriesArray {α β} : Array (Entry α β (Node α β)) :=
  Array.mkArray branching.toNat Entry.null
end PersistentHashMap
/--
A persistent hash map from `α` to `β`: the root node of the tree plus a `size` counter.
By default the root is an entries node whose slots are all `null` and `size` is `0`.
-/
structure PersistentHashMap (α : Type u) (β : Type v) [BEq α] [Hashable α] where
  root : PersistentHashMap.Node α β := .entries PersistentHashMap.mkEmptyEntriesArray
  size : Nat := 0

/-- Short alias for `PersistentHashMap`. -/
abbrev PHashMap (α : Type u) (β : Type v) [BEq α] [Hashable α] :=
  PersistentHashMap α β
namespace PersistentHashMap
/-- The map built from the structure defaults: an all-`null` root and `size = 0`. -/
def empty [BEq α] [Hashable α] : PersistentHashMap α β := {}

/-- `true` exactly when the `size` field of `m` equals `0`. -/
def isEmpty [BEq α] [Hashable α] (m : PersistentHashMap α β) : Bool :=
  m.size == 0

/-- The default-constructed map is the default `PersistentHashMap`. -/
instance [BEq α] [Hashable α] : Inhabited (PersistentHashMap α β) where
  default := {}
/-- An entries node whose slot array is `mkEmptyEntriesArray`. -/
def mkEmptyEntries {α β} : Node α β :=
  .entries mkEmptyEntriesArray
/-- `i` shifted left by `shift` bits. -/
abbrev mul2Shift (i : USize) (shift : USize) : USize :=
  USize.shiftLeft i shift

/-- `i` shifted right by `shift` bits. -/
abbrev div2Shift (i : USize) (shift : USize) : USize :=
  USize.shiftRight i shift

/-- The low bits of `i` selected by the mask `(1 <<< shift) - 1`. -/
abbrev mod2Shift (i : USize) (shift : USize) : USize :=
  let mask : USize := USize.shiftLeft 1 shift - 1
  i.land mask
/-- Holds exactly for nodes built with `Node.collision`. -/
inductive IsCollisionNode : Node α β → Prop where
  | mk (keys : Array α) (vals : Array β) (h : keys.size = vals.size) :
      IsCollisionNode (Node.collision keys vals h)

/-- A node bundled with a proof that it is a collision node. -/
abbrev CollisionNode (α β) :=
  { n : Node α β // IsCollisionNode n }

/-- Holds exactly for nodes built with `Node.entries`. -/
inductive IsEntriesNode : Node α β → Prop where
  | mk (entries : Array (Entry α β (Node α β))) :
      IsEntriesNode (Node.entries entries)

/-- A node bundled with a proof that it is an entries node. -/
abbrev EntriesNode (α β) :=
  { n : Node α β // IsEntriesNode n }
/--
If `ks` and `vs` have the same size, they still have the same size after one element
of each has been replaced with `Array.set`.
-/
private theorem size_set {ks : Array α} {vs : Array β} (h : ks.size = vs.size) (i : Fin ks.size) (j : Fin vs.size) (k : α) (v : β)
: (ks.set i k).size = (vs.set j v).size := by
simp [h]
/--
If `ks` and `vs` have the same size, they still have the same size after one element
has been pushed onto each.
-/
private theorem size_push {ks : Array α} {vs : Array β} (h : ks.size = vs.size) (k : α) (v : β) : (ks.push k).size = (vs.push v).size := by
simp [h]
/--
Linear scan of a collision node starting at position `i`.

* If a stored key at position `i` or later compares equal to `k` (`k == stored`),
  the key and value at that position are replaced by `k` and `v`.
* If the scan runs past the last key, `k` and `v` are appended.

The `Node.entries` case is impossible because of the `IsCollisionNode` proof.
-/
partial def insertAtCollisionNodeAux [BEq α] : CollisionNode α β → Nat → α → β → CollisionNode α β
  | n@⟨Node.collision keys vals heq, _⟩, i, k, v =>
    if h : i < keys.size then
      let keyIdx : Fin keys.size := ⟨i, h⟩
      if k == keys.get keyIdx then
        -- Same position in the value array; in bounds because both arrays have equal size.
        let valIdx : Fin vals.size := ⟨i, by rw [←heq]; assumption⟩
        ⟨Node.collision (keys.set keyIdx k) (vals.set valIdx v) (size_set heq keyIdx valIdx k v),
          IsCollisionNode.mk _ _ _⟩
      else
        insertAtCollisionNodeAux n (i+1) k v
    else
      ⟨Node.collision (keys.push k) (vals.push v) (size_push heq k v), IsCollisionNode.mk _ _ _⟩
  | ⟨Node.entries _, h⟩, _, _, _ => False.elim (nomatch h)
/-- Inserts or replaces `k ↦ v` in a collision node, scanning its keys from position `0`. -/
def insertAtCollisionNode [BEq α] : CollisionNode α β → α → β → CollisionNode α β
  | node, key, val => insertAtCollisionNodeAux node 0 key val
/-- Number of keys stored in a collision node. -/
def getCollisionNodeSize : CollisionNode α β → Nat
  | ⟨.collision ks _ _, _⟩ => ks.size
  -- Impossible: the subtype proof rules out `Node.entries`.
  | ⟨.entries _, notCollision⟩ => False.elim (nomatch notCollision)
/--
Builds a collision node holding exactly the two pairs `k₁ ↦ v₁` and `k₂ ↦ v₂`, in that order.
Both arrays are created with capacity `maxCollisions`.
-/
def mkCollisionNode (k₁ : α) (v₁ : β) (k₂ : α) (v₂ : β) : Node α β :=
  let ks : Array α := ((Array.mkEmpty maxCollisions : Array α).push k₁).push k₂
  let vs : Array β := ((Array.mkEmpty maxCollisions : Array β).push v₁).push v₂
  Node.collision ks vs rfl
/--
Inserts `k ↦ v` below the given node.

Arguments, in order: the node, the not-yet-consumed part of the key's hash, the depth of
the node (the root is called with depth `1` by `insert`), the key and the value.

* Collision node: the pair is inserted with `insertAtCollisionNode`. If `depth >= maxDepth`
  or the resulting size is still below `maxCollisions`, that node is returned. Otherwise all
  of its pairs are re-inserted into a fresh entries node, each using its own hash shifted
  right by `shift * (depth - 1)` bits.
* Entries node: the slot is chosen by the low `shift` bits of `h`. An empty slot receives
  the pair, a child reference is followed with the hash shifted right by `shift` bits, an
  equal key is overwritten, and a different key turns the slot into a two-element
  collision node.
-/
partial def insertAux [BEq α] [Hashable α] : Node α β → USize → USize → α → β → Node α β
  | Node.collision keys vals heq, _, depth, k, v =>
    let newNode := insertAtCollisionNode ⟨Node.collision keys vals heq, IsCollisionNode.mk _ _ _⟩ k v
    if depth >= maxDepth || getCollisionNodeSize newNode < maxCollisions then newNode.val
    else match newNode with
      | ⟨Node.entries _, h⟩ => False.elim (nomatch h)
      | ⟨Node.collision keys vals heq, _⟩ =>
        -- Move every pair of the collision node into a fresh entries node.
        let rec traverse (i : Nat) (entries : Node α β) : Node α β :=
          if h : i < keys.size then
            let k := keys[i]
            let v := vals[i]'(heq ▸ h)
            let h := hash k |>.toUSize
            -- Drop the hash bits used by the levels above this node.
            let h := div2Shift h (shift * (depth - 1))
            traverse (i+1) (insertAux entries h depth k v)
          else
            entries
        traverse 0 mkEmptyEntries
  | Node.entries entries, h, depth, k, v =>
    let j := (mod2Shift h shift).toNat
    Node.entries $ entries.modify j fun entry =>
      match entry with
      | Entry.null        => Entry.entry k v
      | Entry.ref node    => Entry.ref $ insertAux node (div2Shift h shift) (depth+1) k v
      | Entry.entry k' v' =>
        if k == k' then Entry.entry k v
        else Entry.ref $ mkCollisionNode k' v' k v
/--
Reports whether `k` is already stored below the given node. It selects slots the same way
`insertAux` does (`mod2Shift`/`div2Shift` on the hash) and compares keys as `k == stored`.
-/
private partial def insertHasKeyAux [BEq α] : Node α β → USize → α → Bool
  | Node.entries entries, h, k =>
    match entries.get! (mod2Shift h shift).toNat with
    | Entry.null       => false
    | Entry.ref node   => insertHasKeyAux node (div2Shift h shift) k
    | Entry.entry k' _ => k == k'
  | Node.collision keys _ _, _, k => keys.any fun k' => k == k'

/--
Returns the map extended with `k ↦ v`. If a key equal to `k` is already present, its value is
replaced.

`size` grows by one only when `k` was not present before; overwriting an existing key leaves
`size` unchanged. This costs one extra lookup per insertion.
-/
def insert {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → β → PersistentHashMap α β
  | { root := n, size := sz }, k, v =>
    let h := hash k |>.toUSize
    -- Decide the new count before `n` is handed over to `insertAux`.
    let sz := if insertHasKeyAux n h k then sz else sz + 1
    { root := insertAux n h 1 k v, size := sz }
/--
Scans `keys` from position `i` and returns the value stored at the position of the first
key `k'` with `k == k'`, or `none` if there is no such key.
-/
partial def findAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Option β :=
  if h : i < keys.size then
    if k == keys[i] then
      some (vals[i]'(by rw [←heq]; assumption))
    else
      findAtAux keys vals heq (i+1) k
  else
    none
/--
Looks up `k` below the given node, where `h` is the not-yet-consumed part of the key's hash.

In an entries node the slot is chosen by the low `shift` bits of `h`; a child reference is
followed with `h` shifted right by `shift` bits. A collision node is scanned linearly.
Returns the stored value, or `none` if no stored key `k'` satisfies `k == k'`.
-/
partial def findAux [BEq α] : Node α β → USize → α → Option β
  | Node.entries entries, h, k =>
    let j := (mod2Shift h shift).toNat
    match entries.get! j with
    | Entry.null       => none
    | Entry.ref node   => findAux node (div2Shift h shift) k
    | Entry.entry k' v => if k == k' then some v else none
  | Node.collision keys vals heq, _, k => findAtAux keys vals heq 0 k
/-- Looks up `k` in the map, starting at the root with the full hash of `k`. -/
def find? {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → Option β
  | m, k => findAux m.root (hash k).toUSize k
/-- Indexing operation: the same result as `find?`. -/
@[inline] def getOp {_ : BEq α} {_ : Hashable α} (self : PersistentHashMap α β) (idx : α) : Option β :=
  find? self idx

/-- `m[k]` notation for maps. Every index is valid and the result is the `Option` returned by `find?`. -/
instance {_ : BEq α} {_ : Hashable α} : GetElem (PersistentHashMap α β) α (Option β) fun _ _ => True where
  getElem map key _ := find? map key
/-- The value stored for `a`, or `b₀` when `a` is not found. -/
@[inline] def findD {_ : BEq α} {_ : Hashable α} (m : PersistentHashMap α β) (a : α) (b₀ : β) : β :=
  m.find? a |>.getD b₀

/-- The value stored for `a`; panics with "key is not in the map" when `a` is not found. -/
@[inline] def find! {_ : BEq α} {_ : Hashable α} [Inhabited β] (m : PersistentHashMap α β) (a : α) : β :=
  if let some b := m.find? a then
    b
  else
    panic! "key is not in the map"
/--
Scans `keys` from position `i` for the first stored key `k'` with `k == k'` and returns the
stored key together with its value, or `none` if there is no such key.
-/
partial def findEntryAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Option (α × β) :=
  if h : i < keys.size then
    let stored := keys[i]
    if k == stored then
      some (stored, vals[i]'(by rw [←heq]; assumption))
    else
      findEntryAtAux keys vals heq (i+1) k
  else
    none
/--
Like `findAux`, but returns the key stored in the map together with its value.

In an entries node the slot is chosen by the low `shift` bits of `h`; a child reference is
followed with `h` shifted right by `shift` bits. A collision node is scanned linearly.
-/
partial def findEntryAux [BEq α] : Node α β → USize → α → Option (α × β)
  | Node.entries entries, h, k =>
    let j := (mod2Shift h shift).toNat
    match entries.get! j with
    | Entry.null       => none
    | Entry.ref node   => findEntryAux node (div2Shift h shift) k
    | Entry.entry k' v => if k == k' then some (k', v) else none
  | Node.collision keys vals heq, _, k => findEntryAtAux keys vals heq 0 k
/-- Looks up `k` and returns the stored key together with its value, if any. -/
def findEntry? {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → Option (α × β)
  | m, k => findEntryAux m.root (hash k).toUSize k
/-- `true` when some key at position `i` or later in `keys` satisfies `k == k'`. -/
partial def containsAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Bool :=
  if h : i < keys.size then
    -- `||` short-circuits, so the scan stops at the first match.
    k == keys[i] || containsAtAux keys vals heq (i+1) k
  else
    false
/--
Reports whether `k` is stored below the given node, where `h` is the not-yet-consumed part
of the key's hash.

In an entries node the slot is chosen by the low `shift` bits of `h`; a child reference is
followed with `h` shifted right by `shift` bits. A collision node is scanned linearly.
-/
partial def containsAux [BEq α] : Node α β → USize → α → Bool
  | Node.entries entries, h, k =>
    let j := (mod2Shift h shift).toNat
    match entries.get! j with
    | Entry.null       => false
    | Entry.ref node   => containsAux node (div2Shift h shift) k
    | Entry.entry k' _ => k == k'
  | Node.collision keys vals heq, _, k => containsAtAux keys vals heq 0 k
/-- `true` when the map stores a key `k'` with `k == k'`. -/
def contains [BEq α] [Hashable α] : PersistentHashMap α β → α → Bool
  | m, k => containsAux m.root (hash k).toUSize k
/--
Scans the slots of `a` from position `i`, carrying in `acc` the single pair seen so far.

Returns `none` as soon as a child reference or a second pair is met; otherwise returns `acc`
once the end of the array is reached (so `none` if no pair was seen at all).
-/
partial def isUnaryEntries (a : Array (Entry α β (Node α β))) (i : Nat) (acc : Option (α × β)) : Option (α × β) :=
  if h : i < a.size then
    match a[i], acc with
    | Entry.null,      seen   => isUnaryEntries a (i+1) seen
    | Entry.ref _,     _      => none
    | Entry.entry k v, none   => isUnaryEntries a (i+1) (some (k, v))
    | Entry.entry _ _, some _ => none
  else
    acc
/--
If the node stores exactly one key/value pair and no child reference, returns that pair;
otherwise returns `none`.
-/
def isUnaryNode : Node α β → Option (α × β)
-- Entries node: scan every slot with `isUnaryEntries`.
| Node.entries entries => isUnaryEntries entries 0 none
| Node.collision keys vals heq =>
-- Collision node: unary only when it stores exactly one key.
if h : 1 = keys.size then
-- Bounds facts for the `keys[0]` and `vals[0]` accesses below.
have : 0 < keys.size := by rw [←h]; decide
have : 0 < vals.size := by rw [←heq]; assumption
some (keys[0], vals[0])
else
none
/--
Removes `k` from below the given node, where `h` is the not-yet-consumed part of the key's
hash. Returns the resulting node and a flag telling whether a pair was removed.

* Collision node: the key is located with `Array.indexOf?`; if found, the key and the value
  at that position are erased from both arrays.
* Entries node: the slot is chosen by the low `shift` bits of `h`. A matching pair is
  replaced by `Entry.null`. A child reference is followed with `h` shifted right by `shift`
  bits; if the child afterwards holds exactly one pair (`isUnaryNode`), that pair is stored
  directly in the slot instead of the child.

When nothing is removed the original node is returned unchanged.
-/
partial def eraseAux [BEq α] : Node α β → USize → α → Node α β × Bool
  | n@(Node.collision keys vals heq), _, k =>
    match keys.indexOf? k with
    | some idx =>
      let ⟨keys', keq⟩ := keys.eraseIdx' idx
      -- Transport the index from `Fin keys.size` to `Fin vals.size` along `heq`.
      let ⟨vals', veq⟩ := vals.eraseIdx' (Eq.ndrec idx heq)
      have : keys.size - 1 = vals.size - 1 := by rw [heq]
      (Node.collision keys' vals' (keq.trans (this.trans veq.symm)), true)
    | none     => (n, false)
  | n@(Node.entries entries), h, k =>
    let j     := (mod2Shift h shift).toNat
    let entry := entries.get! j
    match entry with
    | Entry.null       => (n, false)
    | Entry.entry k' _ =>
      if k == k' then (Node.entries (entries.set! j Entry.null), true) else (n, false)
    | Entry.ref node   =>
      -- Clear the slot before recursing; presumably so that the array no longer holds a
      -- reference to `node` while the child is being updated (TODO confirm intent).
      let entries := entries.set! j Entry.null
      let (newNode, deleted) := eraseAux node (div2Shift h shift) k
      if !deleted then (n, false)
      else match isUnaryNode newNode with
        | none        => (Node.entries (entries.set! j (Entry.ref newNode)), true)
        | some (k, v) => (Node.entries (entries.set! j (Entry.entry k v)), true)
/--
Returns the map without `k`. `size` is decreased by one when `eraseAux` reports that a pair
was removed and is left unchanged otherwise.
-/
def erase {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → PersistentHashMap α β
  | { root := node, size := count }, key =>
    match eraseAux node (hash key).toUSize key with
    | (node', true)  => { root := node', size := count - 1 }
    | (node', false) => { root := node', size := count }
section
variable {m : Type w → Type w'} [Monad m]
variable {σ : Type w}

/--
Monadic left fold over every key/value pair stored below a node.

`f` receives the accumulator, a key and its value. Pairs of a collision node are visited in
array order; slots of an entries node are visited in array order, descending into child
references and skipping empty slots.
-/
@[specialize] partial def foldlMAux (f : σ → α → β → m σ) : Node α β → σ → m σ
  | Node.collision keys vals heq, acc =>
    let rec traverse (i : Nat) (acc : σ) : m σ := do
      if h : i < keys.size then
        let k := keys[i]
        let v := vals[i]'(heq ▸ h)
        traverse (i+1) (← f acc k v)
      else
        pure acc
    traverse 0 acc
  | Node.entries entries, acc => entries.foldlM (fun acc entry =>
    match entry with
    | Entry.null      => pure acc
    | Entry.entry k v => f acc k v
    | Entry.ref node  => foldlMAux f node acc)
    acc

/-- Monadic left fold over all pairs of the map, starting from `init`. -/
@[specialize] def foldlM {_ : BEq α} {_ : Hashable α} (map : PersistentHashMap α β) (f : σ → α → β → m σ) (init : σ) : m σ :=
  foldlMAux f map.root init

/-- Runs `f` on every pair of the map for its effects. -/
@[specialize] def forM {_ : BEq α} {_ : Hashable α} (map : PersistentHashMap α β) (f : α → β → m PUnit) : m PUnit :=
  map.foldlM (fun _ => f) ⟨⟩

/-- Pure left fold over all pairs of the map, implemented as `foldlM` in the `Id` monad. -/
@[specialize] def foldl {_ : BEq α} {_ : Hashable α} (map : PersistentHashMap α β) (f : σ → α → β → σ) (init : σ) : σ :=
  Id.run $ map.foldlM f init
end
/--
All pairs of the map as a list. Each visited pair is consed onto the front of the
accumulator, so the list is in the reverse of the `foldl` visiting order.
-/
def toList {_ : BEq α} {_ : Hashable α} (m : PersistentHashMap α β) : List (α × β) :=
  m.foldl (fun acc k v => (k, v) :: acc) []
/-- Counters filled in by `collectStats`; every field starts at `0`. -/
structure Stats where
  -- Number of nodes visited (entries and collision nodes).
  numNodes : Nat := 0
  -- Number of `Entry.null` slots seen in entries nodes.
  numNull : Nat := 0
  -- Accumulated from collision nodes (see `collectStats` for the exact formula).
  numCollisions : Nat := 0
  -- Largest depth value passed to `collectStats` for any visited node.
  maxDepth : Nat := 0
/--
Adds the statistics of the given node, located at `depth`, to `stats`.

Every visited node increments `numNodes` and updates `maxDepth` with `depth`. A collision
node sets `numCollisions` to `stats.numCollisions + keys.size - 1`. In an entries node each
`null` slot increments `numNull` and each child reference is visited at `depth + 1`.
-/
partial def collectStats : Node α β → Stats → Nat → Stats
  | Node.collision keys _ _, stats, depth =>
    { stats with
      numNodes      := stats.numNodes + 1,
      numCollisions := stats.numCollisions + keys.size - 1,
      maxDepth      := Nat.max stats.maxDepth depth }
  | Node.entries entries, stats, depth =>
    let visited : Stats :=
      { stats with
        numNodes := stats.numNodes + 1,
        maxDepth := Nat.max stats.maxDepth depth }
    entries.foldl (init := visited) fun acc entry =>
      match entry with
      | Entry.null      => { acc with numNull := acc.numNull + 1 }
      | Entry.ref node  => collectStats node acc (depth + 1)
      | Entry.entry _ _ => acc
/-- Statistics for the whole map: `collectStats` on the root, from empty counters, at depth `1`. -/
def stats {_ : BEq α} {_ : Hashable α} (m : PersistentHashMap α β) : Stats :=
  collectStats m.root ({} : Stats) 1
/-- Renders the four counters as `{ nodes := …, null := …, collisions := …, depth := …}`. -/
def Stats.toString (s : Stats) : String :=
  s!"\{ nodes := {s.numNodes}, null := {s.numNull}, collisions := {s.numCollisions}, depth := {s.maxDepth}}"

instance : ToString Stats where
  toString := Stats.toString
end PersistentHashMap
end Std