lean4-htt/src/Lean/Data/HashSet.lean
Henrik Böving 723c340a8b
perf: fix linearity in (HashSet|HashMap).erase (#3887)
Fixes linearity issues in HashSet/HashMap erase functions.

IR before patch:
```
def Lean.HashMapImp.erase._rarg (x_1 : obj) (x_2 : obj) (x_3 : obj) (x_4 : obj) : obj :=
  let x_5 : obj := proj[0] x_3;
  inc x_5;
  let x_6 : obj := proj[1] x_3;
  inc x_6;
  let x_7 : obj := Array.size  x_6;
  inc x_4;
  let x_8 : obj := app x_2 x_4;
  let x_9 : u64 := unbox x_8;
  dec x_8;
  let x_10 : usize := _private.Lean.Data.HashMap.0.Lean.HashMapImp.mkIdx x_7 x_9 ;
  let x_11 : obj := Array.uget  x_6 x_10 ;
  inc x_11;
  inc x_4;
  inc x_1;
  let x_12 : u8 := Lean.AssocList.contains._rarg x_1 x_4 x_11;
  case x_12 : u8 of
  Bool.false →
    dec x_11;
    dec x_6;
    dec x_5;
    dec x_4;
    dec x_1;
    ret x_3
  Bool.true →
    let x_13 : u8 := isShared x_3;
    case x_13 : u8 of
    Bool.false →
      let x_14 : obj := proj[1] x_3;
      dec x_14;
      let x_15 : obj := proj[0] x_3;
      dec x_15;
      let x_16 : obj := 1;
      let x_17 : obj := Nat.sub x_5 x_16;
      dec x_5;
      let x_18 : obj := Lean.AssocList.erase._rarg x_1 x_4 x_11;
      let x_19 : obj := Array.uset  x_6 x_10 x_18 ;
      set x_3[1] := x_19;
      set x_3[0] := x_17;
      ret x_3
    Bool.true →
      dec x_3;
      let x_20 : obj := 1;
      let x_21 : obj := Nat.sub x_5 x_20;
      dec x_5;
      let x_22 : obj := Lean.AssocList.erase._rarg x_1 x_4 x_11;
      let x_23 : obj := Array.uset  x_6 x_10 x_22 ;
      let x_24 : obj := ctor_0[Lean.HashMapImp.mk] x_21 x_23;
      ret x_24
```

IR after the patch:
```
def Lean.HashMapImp.erase._rarg (x_1 : obj) (x_2 : obj) (x_3 : obj) (x_4 : obj) : obj :=
  let x_5 : u8 := isShared x_3;
  case x_5 : u8 of
  Bool.false →
    let x_6 : obj := proj[0] x_3;
    let x_7 : obj := proj[1] x_3;
    let x_8 : obj := Array.size  x_7;
    inc x_4;
    let x_9 : obj := app x_2 x_4;
    let x_10 : u64 := unbox x_9;
    dec x_9;
    let x_11 : usize := _private.Lean.Data.HashMap.0.Lean.HashMapImp.mkIdx x_8 x_10 ;
    let x_12 : obj := Array.uget  x_7 x_11 ;
    inc x_12;
    inc x_4;
    inc x_1;
    let x_13 : u8 := Lean.AssocList.contains._rarg x_1 x_4 x_12;
    case x_13 : u8 of
    Bool.false →
      dec x_12;
      dec x_4;
      dec x_1;
      ret x_3
    Bool.true →
      let x_14 : obj := 1;
      let x_15 : obj := Nat.sub x_6 x_14;
      dec x_6;
      let x_16 : obj := Lean.AssocList.erase._rarg x_1 x_4 x_12;
      let x_17 : obj := Array.uset  x_7 x_11 x_16 ;
      set x_3[1] := x_17;
      set x_3[0] := x_15;
      ret x_3
  Bool.true →
    let x_18 : obj := proj[0] x_3;
    let x_19 : obj := proj[1] x_3;
    inc x_19;
    inc x_18;
    dec x_3;
    let x_20 : obj := Array.size  x_19;
    inc x_4;
    let x_21 : obj := app x_2 x_4;
    let x_22 : u64 := unbox x_21;
    dec x_21;
    let x_23 : usize := _private.Lean.Data.HashMap.0.Lean.HashMapImp.mkIdx x_20 x_22 ;
    let x_24 : obj := Array.uget  x_19 x_23 ;
    inc x_24;
    inc x_4;
    inc x_1;
    let x_25 : u8 := Lean.AssocList.contains._rarg x_1 x_4 x_24;
    case x_25 : u8 of
    Bool.false →
      dec x_24;
      dec x_4;
      dec x_1;
      let x_26 : obj := ctor_0[Lean.HashMapImp.mk] x_18 x_19;
      ret x_26
    Bool.true →
      let x_27 : obj := 1;
      let x_28 : obj := Nat.sub x_18 x_27;
      dec x_18;
      let x_29 : obj := Lean.AssocList.erase._rarg x_1 x_4 x_24;
      let x_30 : obj := Array.uset  x_19 x_23 x_29 ;
      let x_31 : obj := ctor_0[Lean.HashMapImp.mk] x_28 x_30;
      ret x_31
```

Previously `x_6` (the buckets array) was always `inc`remented; now it is
only `inc`remented if the HashMap itself is shared.
2024-04-12 08:54:21 +00:00

209 lines
7.3 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
-/
prelude
import Init.Data.Nat.Power2
import Init.Data.List.Control
namespace Lean
universe u v w
/--
The bucket array backing a hash set: an array of lists of elements, bundled with a proof
that the number of buckets is a power of two.
-/
def HashSetBucket (α : Type u) :=
  Subtype fun (b : Array (List α)) => b.size.isPowerOfTwo
/--
Replace the bucket stored at index `i` with the list `d`, using `Array.uset`.
`h` is the proof that `i` is within the bounds of the bucket array.
-/
def HashSetBucket.update {α : Type u} (data : HashSetBucket α) (i : USize) (d : List α) (h : i.toNat < data.val.size) : HashSetBucket α :=
⟨ data.val.uset i d h,
-- Rewriting with `Array.size_set` reduces the goal to the power-of-two property already carried by `data`.
by erw [Array.size_set]; apply data.property ⟩
/--
Raw hash set state: an element counter together with the bucket array.
The well-formedness predicate is layered on top of this structure separately.
-/
structure HashSetImp (α : Type u) where
/-- Element counter: starts at `0` in `mkHashSetImp` and is adjusted by `insert` and `erase`. -/
size : Nat
/-- Bucket array; its length is a power of two by construction of `HashSetBucket`. -/
buckets : HashSetBucket α
/--
Create an empty `HashSetImp`. The bucket array consists of
`((capacity * 4) / 3).nextPowerOfTwo` empty lists and `size` starts at `0`.
-/
def mkHashSetImp {α : Type u} (capacity := 8) : HashSetImp α :=
{ size := 0
buckets :=
-- The bucket count is rounded with `nextPowerOfTwo`, which is what the proof below relies on.
⟨mkArray ((capacity * 4) / 3).nextPowerOfTwo [],
by simp; apply Nat.isPowerOfTwo_nextPowerOfTwo⟩ }
namespace HashSetImp
variable {α : Type u}
/- Remark: we use a C implementation because this function is performance critical. -/
/--
Map a 64-bit hash to a bucket index that is provably smaller than `sz`.
The reference implementation masks the hash with `sz - 1`; at runtime the extern
`lean_hashset_mk_idx` is used instead.
-/
@[extern "lean_hashset_mk_idx"]
private def mkIdx {sz : Nat} (hash : UInt64) (h : sz.isPowerOfTwo) : { u : USize // u.toNat < sz } :=
-- TODO: avoid `if` in the reference implementation
let u := hash.toUSize &&& (sz.toUSize - 1)
-- The bound is checked explicitly so that the returned subtype carries a proof.
if h' : u.toNat < sz then
⟨u, h'⟩
else
-- Fallback: index `0` is in bounds because a power of two is positive.
⟨0, by simp [USize.toNat, OfNat.ofNat, USize.ofNat, Fin.ofNat']; apply Nat.pos_of_isPowerOfTwo h⟩
/--
Prepend `a` to the bucket selected by `hashFn a`, without checking whether `a` is
already present in that bucket.
-/
@[inline] def reinsertAux (hashFn : α → UInt64) (data : HashSetBucket α) (a : α) : HashSetBucket α :=
  match mkIdx (hashFn a) data.property with
  | ⟨slot, inBounds⟩ => data.update slot (a :: data.val[slot]) inBounds
/--
Monadic left fold over every element stored in `data`: buckets are visited in array
order, and the elements of each bucket in list order. The accumulator starts at `d`.
-/
@[inline] def foldBucketsM {δ : Type w} {m : Type w → Type w} [Monad m] (data : HashSetBucket α) (d : δ) (f : δ → α → m δ) : m δ :=
  Array.foldlM (fun acc bucket => List.foldlM f acc bucket) d data.val
/-- Pure left fold over every element stored in `data`; runs `foldBucketsM` in the `Id` monad. -/
@[inline] def foldBuckets {δ : Type w} (data : HashSetBucket α) (d : δ) (f : δ → α → δ) : δ :=
  Id.run (foldBucketsM data d f)
/-- Monadic left fold over the elements of `h`, starting from the accumulator `d`. -/
@[inline] def foldM {δ : Type w} {m : Type w → Type w} [Monad m] (f : δ → α → m δ) (d : δ) (h : HashSetImp α) : m δ :=
  match h with
  | ⟨_, buckets⟩ => foldBucketsM buckets d f
/-- Pure left fold over the elements of `m`, starting from the accumulator `d`. -/
@[inline] def fold {δ : Type w} (f : δ → α → δ) (d : δ) (m : HashSetImp α) : δ :=
  match m with
  | ⟨_, buckets⟩ => foldBuckets buckets d f
/-- Run the action `f` on every element stored in `data`, bucket by bucket. -/
@[inline] def forBucketsM {m : Type w → Type w} [Monad m] (data : HashSetBucket α) (f : α → m PUnit) : m PUnit :=
  Array.forM (fun bucket => List.forM bucket f) data.val
/-- Run the action `f` on every element of `h`. -/
@[inline] def forM {m : Type w → Type w} [Monad m] (f : α → m PUnit) (h : HashSetImp α) : m PUnit :=
  match h with
  | ⟨_, buckets⟩ => forBucketsM buckets f
/--
Look up `a` in the bucket selected by `hash a` and return the first stored element
`a'` with `a == a'`, if there is one.
-/
def find? [BEq α] [Hashable α] (m : HashSetImp α) (a : α) : Option α :=
  let ⟨idx, inBounds⟩ := mkIdx (hash a) m.buckets.property
  m.buckets.val[idx].find? (a == ·)
/-- Check whether the bucket selected by `hash a` contains `a`. -/
def contains [BEq α] [Hashable α] (m : HashSetImp α) (a : α) : Bool :=
  let ⟨idx, inBounds⟩ := mkIdx (hash a) m.buckets.property
  m.buckets.val[idx].contains a
/--
Re-insert the entries of `source`, starting at bucket index `i`, into `target`.
Each bucket is taken out of `source`, its elements are added to `target` with
`reinsertAux hash`, and the function then recurses on index `i+1`.
Once `i` reaches `source.size`, `target` is returned.
-/
def moveEntries [Hashable α] (i : Nat) (source : Array (List α)) (target : HashSetBucket α) : HashSetBucket α :=
if h : i < source.size then
let idx : Fin source.size := ⟨i, h⟩
let es : List α := source.get idx
-- We remove `es` from `source` to make sure we can reuse its memory cells when performing es.foldl
let source := source.set idx []
let target := es.foldl (reinsertAux hash) target
moveEntries (i+1) source target
else
target
-- The number of buckets still to be processed is the termination measure.
termination_by source.size - i
/--
Grow the table: allocate a bucket array with twice as many (empty) buckets as
`buckets`, then move every entry of `buckets` into it with `moveEntries`.
The given `size` is stored unchanged in the result.
-/
def expand [Hashable α] (size : Nat) (buckets : HashSetBucket α) : HashSetImp α :=
  let bucketsNew : HashSetBucket α := ⟨
    mkArray (buckets.val.size * 2) [],
    -- Doubling a power of two yields a power of two.
    by simp; apply Nat.mul2_isPowerOfTwo_of_isPowerOfTwo buckets.property
  ⟩
  { size    := size,
    buckets := moveEntries 0 buckets.val bucketsNew }
/--
Insert `a` into `m`.
If the bucket selected by `hash a` already contains `a`, the stored element is
replaced by `a` and `size` is unchanged. Otherwise `a` is prepended to the bucket and
`size` grows by one; when the new size exceeds the number of buckets the table is
rebuilt with `expand`.
-/
def insert [BEq α] [Hashable α] (m : HashSetImp α) (a : α) : HashSetImp α :=
match m with
| ⟨size, buckets⟩ =>
let ⟨i, h⟩ := mkIdx (hash a) buckets.property
let bkt := buckets.val[i]
if bkt.contains a
then ⟨size, buckets.update i (bkt.replace a a) h⟩
else
let size' := size + 1
let buckets' := buckets.update i (a :: bkt) h
-- Keep the table as is while there is at most one element per bucket on average.
if size' ≤ buckets.val.size
then { size := size', buckets := buckets' }
else expand size' buckets'
/--
Remove `a` from `m`.
If the bucket selected by `hash a` contains `a`, it is erased from that bucket and
`size` is decremented; otherwise the set is returned with the same `size` and buckets.
-/
def erase [BEq α] [Hashable α] (m : HashSetImp α) (a : α) : HashSetImp α :=
match m with
| ⟨ size, buckets ⟩ =>
let ⟨i, h⟩ := mkIdx (hash a) buckets.property
let bkt := buckets.val[i]
if bkt.contains a then
⟨size - 1, buckets.update i (bkt.erase a) h⟩
else
-- NOTE(review): this branch rebuilds the pair from its components instead of returning `m`;
-- presumably this is deliberate so the compiler can use `m` linearly - confirm before simplifying.
⟨size, buckets⟩
/--
Well-formedness predicate for `HashSetImp`: it holds exactly for the values that can be
built from `mkHashSetImp` by a sequence of `insert` and `erase` operations.
-/
inductive WellFormed [BEq α] [Hashable α] : HashSetImp α → Prop where
  /-- A freshly created set is well-formed. -/
  | mkWff (n : Nat) : WellFormed (mkHashSetImp n)
  /-- Inserting into a well-formed set yields a well-formed set. -/
  | insertWff (m : HashSetImp α) (a : α) : WellFormed m → WellFormed (insert m a)
  /-- Erasing from a well-formed set yields a well-formed set. -/
  | eraseWff (m : HashSetImp α) (a : α) : WellFormed m → WellFormed (erase m a)
end HashSetImp
/-- A hash set: a `HashSetImp` bundled with a proof that it is `WellFormed`. -/
def HashSet (α : Type u) [BEq α] [Hashable α] :=
  Subtype fun (m : HashSetImp α) => m.WellFormed
open HashSetImp
/-- Create an empty `HashSet`; `capacity` is forwarded to `mkHashSetImp`. -/
def mkHashSet {α : Type u} [BEq α] [Hashable α] (capacity := 8) : HashSet α :=
  Subtype.mk (mkHashSetImp capacity) (WellFormed.mkWff capacity)
namespace HashSet
/-- The empty hash set, created by `mkHashSet` with its default capacity. -/
@[inline] def empty [BEq α] [Hashable α] : HashSet α :=
  (mkHashSet : HashSet α)
/-- The default `HashSet` is the empty one produced by `mkHashSet`. -/
instance [BEq α] [Hashable α] : Inhabited (HashSet α) :=
  ⟨mkHashSet⟩
/-- Empty-collection notation for `HashSet` elaborates to `mkHashSet`. -/
instance [BEq α] [Hashable α] : EmptyCollection (HashSet α) where
  emptyCollection := mkHashSet
variable {α : Type u} {_ : BEq α} {_ : Hashable α}
/--
Insert `a` into the set by delegating to `HashSetImp.insert`; the well-formedness proof
is extended with `WellFormed.insertWff`.
-/
@[inline] def insert (m : HashSet α) (a : α) : HashSet α :=
  let ⟨imp, wf⟩ := m
  ⟨imp.insert a, WellFormed.insertWff imp a wf⟩
/--
Remove `a` from the set by delegating to `HashSetImp.erase`; the well-formedness proof
is extended with `WellFormed.eraseWff`.
-/
@[inline] def erase (m : HashSet α) (a : α) : HashSet α :=
  let ⟨imp, wf⟩ := m
  ⟨imp.erase a, WellFormed.eraseWff imp a wf⟩
/-- Return the stored element that is `==` to `a`, if any (see `HashSetImp.find?`). -/
@[inline] def find? (m : HashSet α) (a : α) : Option α :=
  m.val.find? a
/-- Check whether `a` is a member of the set (see `HashSetImp.contains`). -/
@[inline] def contains (m : HashSet α) (a : α) : Bool :=
  m.val.contains a
/-- Monadic left fold over the elements of `h`, starting from `init`. -/
@[inline] def foldM {δ : Type w} {m : Type w → Type w} [Monad m] (f : δ → α → m δ) (init : δ) (h : HashSet α) : m δ :=
  h.val.foldM f init
/-- Pure left fold over the elements of `m`, starting from `init`. -/
@[inline] def fold {δ : Type w} (f : δ → α → δ) (init : δ) (m : HashSet α) : δ :=
  m.val.fold f init
/-- Run the action `f` on every element of `h`. -/
@[inline] def forM {m : Type w → Type w} [Monad m] (h : HashSet α) (f : α → m PUnit) : m PUnit :=
  h.val.forM f
/-- `ForM` support for `HashSet`, implemented by `HashSet.forM`. -/
instance : ForM m (HashSet α) α :=
  ⟨HashSet.forM⟩
/-- `for ... in` support for `HashSet`, derived from its `ForM` instance via `ForM.forIn`. -/
instance : ForIn m (HashSet α) α :=
  ⟨ForM.forIn⟩
/-- The value of the `size` field of the underlying `HashSetImp`. -/
@[inline] def size (m : HashSet α) : Nat :=
  m.val.size
/-- `true` exactly when `m.size` is `0`. -/
@[inline] def isEmpty (m : HashSet α) : Bool :=
  decide (m.size = 0)
/--
Collect the elements of `m` into a list. Each visited element is prepended, so the
resulting list is in the reverse of the fold's visiting order.
-/
def toList (m : HashSet α) : List α :=
  m.fold (fun acc a => a :: acc) []
/-- Collect the elements of `m` into an array, pushing them in the fold's visiting order. -/
def toArray (m : HashSet α) : Array α :=
  m.fold (fun acc a => acc.push a) #[]
/-- The length of the underlying bucket array. -/
def numBuckets (m : HashSet α) : Nat :=
  match m with
  | ⟨imp, _⟩ => imp.buckets.val.size
/--
Insert many elements into a HashSet.

Every element produced by iterating over `as` is inserted into `s` in iteration order,
and the resulting set is returned.
-/
def insertMany [ForIn Id ρ α] (s : HashSet α) (as : ρ) : HashSet α := Id.run do
  let mut s := s
  -- The loop body is indented deeper than `return` so that `return s` runs after the
  -- loop instead of being parsed as part of its body.
  for a in as do
    s := s.insert a
  return s
/--
`O(|t|)` amortized. Merge two `HashSet`s.

Every element of `t` is inserted into `s`, and the resulting set is returned.
-/
@[inline]
def merge {α : Type u} [BEq α] [Hashable α] (s t : HashSet α) : HashSet α :=
  t.fold (fun acc a => acc.insert a) s
  -- We don't use `insertMany` here because it gives weird universes.