feat(library/init/data): add PersistentHashMap
This commit is contained in:
parent
25d8c76910
commit
c371b43970
4 changed files with 262 additions and 0 deletions
|
|
@ -92,6 +92,9 @@ def fset (a : Array α) (i : @& Fin a.size) (v : α) : Array α :=
|
|||
theorem szFSetEq (a : Array α) (i : Fin a.size) (v : α) : (fset a i v).size = a.size :=
|
||||
rfl
|
||||
|
||||
theorem szPushEq (a : Array α) (v : α) : (push a v).size = a.size + 1 :=
|
||||
rfl
|
||||
|
||||
/- Low-level version of `fset` which is as fast as a C array fset.
|
||||
`Fin` values are represented as tag pointers in the Lean runtime. Thus,
|
||||
`fset` may be slightly slower than `uset`. -/
|
||||
|
|
|
|||
209
library/init/data/persistenthashmap/basic.lean
Normal file
209
library/init/data/persistenthashmap/basic.lean
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
/-
|
||||
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import init.data.array
|
||||
import init.data.hashable
|
||||
universes u v w w'
|
||||
|
||||
namespace PersistentHashMap
|
||||
|
||||
inductive Entry (α : Type u) (β : Type v) (σ : Type w)
|
||||
| entry {} (key : α) (val : β) : Entry
|
||||
| ref {} (node : σ) : Entry
|
||||
| null {} : Entry
|
||||
|
||||
instance Entry.inhabited {α β σ} : Inhabited (Entry α β σ) := ⟨Entry.null⟩
|
||||
|
||||
inductive Node (α : Type u) (β : Type v) : Type (max u v)
|
||||
| entries (es : Array (Entry α β Node)) : Node
|
||||
| collision (ks : Array α) (vs : Array β) (h : ks.size = vs.size) : Node
|
||||
|
||||
instance Node.inhabited {α β} : Inhabited (Node α β) := ⟨Node.entries Array.empty⟩
|
||||
|
||||
abbrev shift : USize := 5
|
||||
abbrev branching : USize := USize.ofNat (2 ^ shift.toNat)
|
||||
abbrev maxDepth : USize := 7
|
||||
abbrev maxCollisions : Nat := 4
|
||||
|
||||
def mkEmptyEntriesArray {α β} : Array (Entry α β (Node α β)) :=
|
||||
(Array.mkArray PersistentHashMap.branching.toNat PersistentHashMap.Entry.null)
|
||||
|
||||
end PersistentHashMap
|
||||
|
||||
structure PersistentHashMap (α : Type u) (β : Type v) :=
|
||||
(root : PersistentHashMap.Node α β := PersistentHashMap.Node.entries PersistentHashMap.mkEmptyEntriesArray)
|
||||
(size : Nat := 0)
|
||||
|
||||
abbrev PHashMap (α : Type u) (β : Type v) := PersistentHashMap α β
|
||||
|
||||
namespace PersistentHashMap
|
||||
variables {α : Type u} {β : Type v}
|
||||
|
||||
def empty : PersistentHashMap α β := {}
|
||||
|
||||
instance : Inhabited (PersistentHashMap α β) := ⟨{}⟩
|
||||
|
||||
def mkEmptyEntries {α β} : Node α β :=
|
||||
Node.entries mkEmptyEntriesArray
|
||||
|
||||
abbrev mul2Shift (i : USize) (shift : USize) : USize := USize.shift_left i shift
|
||||
abbrev div2Shift (i : USize) (shift : USize) : USize := USize.shift_right i shift
|
||||
abbrev mod2Shift (i : USize) (shift : USize) : USize := USize.land i ((USize.shift_left 1 shift) - 1)
|
||||
|
||||
inductive IsCollisionNode : Node α β → Prop
|
||||
| mk (keys : Array α) (vals : Array β) (h : keys.size = vals.size) : IsCollisionNode (Node.collision keys vals h)
|
||||
|
||||
abbrev CollisionNode (α β) := { n : Node α β // IsCollisionNode n }
|
||||
|
||||
inductive IsEntriesNode : Node α β → Prop
|
||||
| mk (entries : Array (Entry α β (Node α β))) : IsEntriesNode (Node.entries entries)
|
||||
|
||||
abbrev EntriesNode (α β) := { n : Node α β // IsEntriesNode n }
|
||||
|
||||
private theorem fsetSizeEq {ks : Array α} {vs : Array β} (h : ks.size = vs.size) (i : Fin ks.size) (j : Fin vs.size) (k : α) (v : β)
|
||||
: (ks.fset i k).size = (vs.fset j v).size :=
|
||||
have h₁ : (ks.fset i k).size = ks.size from Array.szFSetEq _ _ _;
|
||||
have h₂ : (vs.fset j v).size = vs.size from Array.szFSetEq _ _ _;
|
||||
(h₁.trans h).trans h₂.symm
|
||||
|
||||
private theorem pushSizeEq {ks : Array α} {vs : Array β} (h : ks.size = vs.size) (k : α) (v : β) : (ks.push k).size = (vs.push v).size :=
|
||||
have h₁ : (ks.push k).size = ks.size + 1 from Array.szPushEq _ _;
|
||||
have h₂ : (vs.push v).size = vs.size + 1 from Array.szPushEq _ _;
|
||||
have h₃ : ks.size + 1 = vs.size + 1 from h ▸ rfl;
|
||||
(h₁.trans h₃).trans h₂.symm
|
||||
|
||||
partial def insertAtCollisionNodeAux [HasBeq α] : CollisionNode α β → Nat → α → β → CollisionNode α β
|
||||
| n@⟨Node.collision keys vals heq, _⟩ i k v :=
|
||||
if h : i < keys.size then
|
||||
let idx : Fin keys.size := ⟨i, h⟩;
|
||||
let k' := keys.fget idx;
|
||||
if k == k' then
|
||||
let j : Fin vals.size := ⟨i, heq ▸ h⟩;
|
||||
⟨Node.collision (keys.fset idx k) (vals.fset j v) (fsetSizeEq heq idx j k v), IsCollisionNode.mk _ _ _⟩
|
||||
else insertAtCollisionNodeAux n (i+1) k v
|
||||
else
|
||||
⟨Node.collision (keys.push k) (vals.push v) (pushSizeEq heq k v), IsCollisionNode.mk _ _ _⟩
|
||||
| ⟨Node.entries _, h⟩ _ _ _ := False.elim (nomatch h)
|
||||
|
||||
def insertAtCollisionNode [HasBeq α] : CollisionNode α β → α → β → CollisionNode α β :=
|
||||
fun n k v => insertAtCollisionNodeAux n 0 k v
|
||||
|
||||
def getCollisionNodeSize : CollisionNode α β → Nat
|
||||
| ⟨Node.collision keys _ _, _⟩ := keys.size
|
||||
| ⟨Node.entries _, h⟩ := False.elim (nomatch h)
|
||||
|
||||
def mkCollisionNode (k₁ : α) (v₁ : β) (k₂ : α) (v₂ : β) : Node α β :=
|
||||
let ks : Array α := Array.mkEmpty maxCollisions;
|
||||
let ks := (ks.push k₁).push k₂;
|
||||
let vs : Array β := Array.mkEmpty maxCollisions;
|
||||
let vs := (vs.push v₁).push v₂;
|
||||
Node.collision ks vs rfl
|
||||
|
||||
partial def insertAux [HasBeq α] [Hashable α] : Node α β → USize → USize → α → β → Node α β
|
||||
| (Node.collision keys vals heq) _ depth k v :=
|
||||
let newNode := insertAtCollisionNode ⟨Node.collision keys vals heq, IsCollisionNode.mk _ _ _⟩ k v;
|
||||
if depth >= maxDepth || getCollisionNodeSize newNode < maxCollisions then newNode.val
|
||||
else match newNode with
|
||||
| ⟨Node.entries _, h⟩ => False.elim (nomatch h)
|
||||
| ⟨Node.collision keys vals heq, _⟩ =>
|
||||
let entries : Node α β := mkEmptyEntries;
|
||||
keys.iterate entries $ fun i k entries =>
|
||||
let v := vals.fget ⟨i.val, heq ▸ i.isLt⟩;
|
||||
let h := hash k;
|
||||
-- dbgTrace ("toCollision " ++ toString i ++ ", h: " ++ toString h ++ ", depth: " ++ toString depth ++ ", h': " ++
|
||||
-- toString (div2Shift h (shift * (depth - 1)))) $ fun _ =>
|
||||
let h := div2Shift h (shift * (depth - 1));
|
||||
insertAux entries h depth k v
|
||||
| (Node.entries entries) h depth k v :=
|
||||
let j := (mod2Shift h shift).toNat;
|
||||
Node.entries $ entries.modify j $ fun entry =>
|
||||
match entry with
|
||||
| Entry.null => Entry.entry k v
|
||||
| Entry.ref node => Entry.ref $ insertAux node (div2Shift h shift) (depth+1) k v
|
||||
| Entry.entry k' v' =>
|
||||
if k == k' then Entry.entry k v
|
||||
else Entry.ref $ mkCollisionNode k' v' k v
|
||||
|
||||
def insert [HasBeq α] [Hashable α] : PersistentHashMap α β → α → β → PersistentHashMap α β
|
||||
| { root := n, size := sz } k v := { root := insertAux n (hash k) 1 k v, size := sz + 1 }
|
||||
|
||||
partial def findAtAux [HasBeq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) : Nat → α → Option β
|
||||
| i k :=
|
||||
if h : i < keys.size then
|
||||
let k' := keys.fget ⟨i, h⟩;
|
||||
if k == k' then some (vals.fget ⟨i, heq ▸ h⟩)
|
||||
else findAtAux (i+1) k
|
||||
else none
|
||||
|
||||
partial def findAux [HasBeq α] : Node α β → USize → α → Option β
|
||||
| (Node.entries entries) h k :=
|
||||
let j := (mod2Shift h shift).toNat;
|
||||
match entries.get j with
|
||||
| Entry.null => none
|
||||
| Entry.ref node => findAux node (div2Shift h shift) k
|
||||
| Entry.entry k' v => if k == k' then some v else none
|
||||
| (Node.collision keys vals heq) _ k := findAtAux keys vals heq 0 k
|
||||
|
||||
def find [HasBeq α] [Hashable α] : PersistentHashMap α β → α → Option β
|
||||
| { root := n, .. } k := findAux n (hash k) k
|
||||
|
||||
section
|
||||
variables {m : Type w → Type w'} [Monad m]
|
||||
variables {σ : Type w}
|
||||
|
||||
@[specialize] partial def mfoldlAux (f : σ → α → β → m σ) : Node α β → σ → m σ
|
||||
| (Node.collision keys vals heq) acc := keys.miterate acc $ fun i k acc => f acc k (vals.fget ⟨i.val, heq ▸ i.isLt⟩)
|
||||
| (Node.entries entries) acc := entries.mfoldl (fun acc entry =>
|
||||
match entry with
|
||||
| Entry.null => pure acc
|
||||
| Entry.entry k v => f acc k v
|
||||
| Entry.ref node => mfoldlAux node acc)
|
||||
acc
|
||||
|
||||
@[specialize] def mfoldl (map : PersistentHashMap α β) (f : σ → α → β → m σ) (acc : σ) : m σ :=
|
||||
mfoldlAux f map.root acc
|
||||
|
||||
@[specialize] def foldl (map : PersistentHashMap α β) (f : σ → α → β → σ) (acc : σ) : σ :=
|
||||
Id.run $ map.mfoldl f acc
|
||||
end
|
||||
|
||||
def toList (m : PersistentHashMap α β) : List (α × β) :=
|
||||
m.foldl (fun ps k v => (k, v) :: ps) []
|
||||
|
||||
structure Stats :=
|
||||
(numNodes : Nat := 0)
|
||||
(numNull : Nat := 0)
|
||||
(numCollisions : Nat := 0)
|
||||
(maxDepth : Nat := 0)
|
||||
|
||||
partial def collectStats : Node α β → Stats → Nat → Stats
|
||||
| (Node.collision keys _ _) stats depth :=
|
||||
{ numNodes := stats.numNodes + 1,
|
||||
numCollisions := stats.numCollisions + keys.size - 1,
|
||||
maxDepth := Nat.max stats.maxDepth depth,
|
||||
.. stats }
|
||||
| (Node.entries entries) stats depth :=
|
||||
let stats :=
|
||||
{ numNodes := stats.numNodes + 1,
|
||||
maxDepth := Nat.max stats.maxDepth depth,
|
||||
.. stats };
|
||||
entries.foldl (fun stats entry =>
|
||||
match entry with
|
||||
| Entry.null => { numNull := stats.numNull + 1, .. stats }
|
||||
| Entry.ref node => collectStats node stats (depth + 1)
|
||||
| Entry.entry _ _ => stats)
|
||||
stats
|
||||
|
||||
def stats (m : PersistentHashMap α β) : Stats :=
|
||||
collectStats m.root {} 1
|
||||
|
||||
def Stats.toString (s : Stats) : String :=
|
||||
"{ nodes := " ++ toString s.numNodes ++ ", null := " ++ toString s.numNull ++
|
||||
", collisions := " ++ toString s.numCollisions ++ ", depth := " ++ toString s.maxDepth ++ "}"
|
||||
|
||||
instance : HasToString Stats := ⟨Stats.toString⟩
|
||||
|
||||
end PersistentHashMap
|
||||
7
library/init/data/persistenthashmap/default.lean
Normal file
7
library/init/data/persistenthashmap/default.lean
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
/-
|
||||
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import init.data.persistenthashmap.basic
|
||||
43
tests/playground/phashmap.lean
Normal file
43
tests/playground/phashmap.lean
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
import init.data.persistenthashmap
|
||||
import init.lean.format
|
||||
open Lean PersistentHashMap
|
||||
|
||||
abbrev Map := PersistentHashMap Nat Nat
|
||||
|
||||
partial def formatMap : Node Nat Nat → Format
|
||||
| (Node.collision keys vals _) := Format.sbracket $
|
||||
keys.size.fold
|
||||
(fun i fmt =>
|
||||
let k := keys.get i;
|
||||
let v := vals.get i;
|
||||
let p := if i > 0 then fmt ++ format "," ++ Format.line else fmt;
|
||||
p ++ "c@" ++ Format.paren (format k ++ " => " ++ format v))
|
||||
Format.nil
|
||||
| (Node.entries entries) := Format.sbracket $
|
||||
entries.size.fold
|
||||
(fun i fmt =>
|
||||
let entry := entries.get i;
|
||||
let p := if i > 0 then fmt ++ format "," ++ Format.line else fmt;
|
||||
p ++
|
||||
match entry with
|
||||
| Entry.null => "<null>"
|
||||
| Entry.ref node => formatMap node
|
||||
| Entry.entry k v => Format.paren (format k ++ " => " ++ format v))
|
||||
Format.nil
|
||||
|
||||
def mkMap (n : Nat) : Map :=
|
||||
n.fold (fun i m => m.insert i (i*10)) PersistentHashMap.empty
|
||||
|
||||
def check (n : Nat) (m : Map) : IO Unit :=
|
||||
n.mfor $ fun i =>
|
||||
match m.find i with
|
||||
| none => IO.println ("failed to find " ++ toString i)
|
||||
| some v => unless (v == i*10) (IO.println ("unexpected value " ++ toString i ++ " => " ++ toString v))
|
||||
|
||||
def main (xs : List String) : IO Unit :=
|
||||
do
|
||||
let n := 1000000;
|
||||
let m := mkMap n;
|
||||
-- IO.println (formatMap m.root);
|
||||
IO.println m.stats;
|
||||
check n m
|
||||
Loading…
Add table
Reference in a new issue