From d9e7ded5afce0b321055439024374d7ec00805be Mon Sep 17 00:00:00 2001 From: Paul Reichert Date: Tue, 18 Feb 2025 09:29:24 +0100 Subject: [PATCH] feat: getThenInsertIfNew? and partition functions for the tree map (#7109) This PR implements the `getThenInsertIfNew?` and `partition` functions on the tree map. --------- Co-authored-by: Paul Reichert <6992158+datokrat@users.noreply.github.com> --- src/Std/Data/DHashMap/Basic.lean | 2 +- src/Std/Data/DTreeMap/Basic.lean | 34 +++++++++++++++++ .../Data/DTreeMap/Internal/Operations.lean | 38 +++++++++++++++++++ src/Std/Data/DTreeMap/Internal/WF/Defs.lean | 21 ++++++++++ src/Std/Data/DTreeMap/Raw.lean | 21 ++++++++++ src/Std/Data/TreeMap/Basic.lean | 10 +++++ src/Std/Data/TreeMap/Raw.lean | 12 +++++- src/Std/Data/TreeSet/Basic.lean | 5 +++ src/Std/Data/TreeSet/Raw.lean | 4 ++ 9 files changed, 144 insertions(+), 3 deletions(-) diff --git a/src/Std/Data/DHashMap/Basic.lean b/src/Std/Data/DHashMap/Basic.lean index b3d0cad069..5a8b318346 100644 --- a/src/Std/Data/DHashMap/Basic.lean +++ b/src/Std/Data/DHashMap/Basic.lean @@ -196,7 +196,7 @@ section Unverified (init : δ) (b : DHashMap α β) : δ := b.1.fold f init -/-- Partition a hashset into two hashsets based on a predicate. -/ +/-- Partition a hash map into two hash map based on a predicate. -/ @[inline] def partition (f : (a : α) → β a → Bool) (m : DHashMap α β) : DHashMap α β × DHashMap α β := m.fold (init := (∅, ∅)) fun ⟨l, r⟩ a b => diff --git a/src/Std/Data/DTreeMap/Basic.lean b/src/Std/Data/DTreeMap/Basic.lean index 61f06e4d17..d5d01996c8 100644 --- a/src/Std/Data/DTreeMap/Basic.lean +++ b/src/Std/Data/DTreeMap/Basic.lean @@ -137,6 +137,24 @@ def containsThenInsertIfNew (t : DTreeMap α β cmp) (a : α) (b : β a) : let p := t.inner.containsThenInsertIfNew a b t.wf.balanced (p.1, ⟨p.2.impl, t.wf.containsThenInsertIfNew⟩) +/-- +Checks whether a key is present in a map, returning the associated value, and inserts a value for +the key if it was not found. + +If the returned value is `some v`, then the returned map is unaltered. If it is `none`, then the +returned map has a new value inserted. + +Equivalent to (but potentially faster than) calling `get?` followed by `insertIfNew`. + +Uses the `LawfulEqCmp` instance to cast the retrieved value to the correct type. +-/ +@[inline] +def getThenInsertIfNew? [LawfulEqCmp cmp] (t : DTreeMap α β cmp) (a : α) (b : β a) : + Option (β a) × DTreeMap α β cmp := + letI : Ord α := ⟨cmp⟩ + let p := t.inner.getThenInsertIfNew? a b t.wf.balanced + (p.1, ⟨p.2, t.wf.getThenInsertIfNew?⟩) + /-- Returns `true` if there is a mapping for the given key `a` or a key that is equal to `a` according to the comparator `cmp`. There is also a `Prop`-valued version @@ -575,6 +593,13 @@ namespace Const variable {β : Type v} +@[inline, inherit_doc DTreeMap.getThenInsertIfNew?] +def getThenInsertIfNew? (t : DTreeMap α β cmp) (a : α) (b : β) : + Option β × DTreeMap α β cmp := + letI : Ord α := ⟨cmp⟩ + let p := Impl.Const.getThenInsertIfNew? a b t.inner t.wf.balanced + (p.1, ⟨p.2, t.wf.constGetThenInsertIfNew?⟩) + @[inline, inherit_doc DTreeMap.get?] def get? (t : DTreeMap α β cmp) (a : α) : Option β := letI : Ord α := ⟨cmp⟩; Impl.Const.get? a t.inner @@ -746,6 +771,15 @@ def foldr (f : δ → (a : α) → β a → δ) (init : δ) (t : DTreeMap α β def revFold (f : δ → (a : α) → β a → δ) (init : δ) (t : DTreeMap α β cmp) : δ := foldr f init t +/-- Partitions a tree map into two tree maps based on a predicate. -/ +@[inline] def partition (f : (a : α) → β a → Bool) + (t : DTreeMap α β cmp) : DTreeMap α β cmp × DTreeMap α β cmp := + t.foldl (init := (∅, ∅)) fun ⟨l, r⟩ a b => + if f a b then + (l.insert a b, r) + else + (l, r.insert a b) + /-- Carries out a monadic action on each mapping in the tree map in ascending order. -/ @[inline] def forM (f : (a : α) → β a → m PUnit) (t : DTreeMap α β cmp) : m PUnit := diff --git a/src/Std/Data/DTreeMap/Internal/Operations.lean b/src/Std/Data/DTreeMap/Internal/Operations.lean index b555893c1f..fa6bcf0406 100644 --- a/src/Std/Data/DTreeMap/Internal/Operations.lean +++ b/src/Std/Data/DTreeMap/Internal/Operations.lean @@ -399,6 +399,25 @@ def containsThenInsertIfNew! [Ord α] (k : α) (v : β k) (t : Impl α β) : Bool × Impl α β := if t.contains k then (true, t) else (false, t.insert! k v) +/-- Implementation detail of the tree map -/ +@[inline] +def getThenInsertIfNew? [Ord α] [LawfulEqOrd α] (k : α) (v : β k) (t : Impl α β) (ht : t.Balanced) : + Option (β k) × Impl α β := + match t.get? k with + | none => (none, t.insertIfNew k v ht |>.impl) + | some b => (some b, t) + +/-- +Slower version of `getThenInsertIfNew?` which can be used in the absence of balance +information but still assumes the preconditions of `getThenInsertIfNew?`, otherwise might panic. +-/ +@[inline] +def getThenInsertIfNew?! [Ord α] [LawfulEqOrd α] (k : α) (v : β k) (t : Impl α β) : + Option (β k) × Impl α β := + match t.get? k with + | none => (none, t.insertIfNew! k v) + | some b => (some b, t) + /-- Removes the mapping with key `k`, if it exists. -/ def erase [Ord α] (k : α) (t : Impl α β) (h : t.Balanced) : SizedBalancedTree α β (t.size - 1) t.size := @@ -583,6 +602,25 @@ namespace Const variable {β : Type v} +/-- Implementation detail of the tree map -/ +@[inline] +def getThenInsertIfNew? [Ord α] (k : α) (v : β) (t : Impl α (fun _ => β)) + (ht : t.Balanced) : Option β × Impl α (fun _ => β) := + match get? k t with + | none => (none, t.insertIfNew k v ht |>.impl) + | some b => (some b, t) + +/-- +Slower version of `getThenInsertIfNew?` which can be used in the absence of balance +information but still assumes the preconditions of `getThenInsertIfNew?`, otherwise might panic. +-/ +@[inline] +def getThenInsertIfNew?! [Ord α] (k : α) (v : β) (t : Impl α (fun _ => β)) + : Option β × Impl α (fun _ => β) := + match get? k t with + | none => (none, t.insertIfNew! k v) + | some b => (some b, t) + /-- Transforms a list of mappings into a tree map. -/ @[inline] def ofArray [Ord α] (a : Array (α × β)) : Impl α (fun _ => β) := insertMany empty a balanced_empty |>.val diff --git a/src/Std/Data/DTreeMap/Internal/WF/Defs.lean b/src/Std/Data/DTreeMap/Internal/WF/Defs.lean index 5f0feb7dd1..aabef8b3d0 100644 --- a/src/Std/Data/DTreeMap/Internal/WF/Defs.lean +++ b/src/Std/Data/DTreeMap/Internal/WF/Defs.lean @@ -21,6 +21,7 @@ set_option linter.all true universe u v w variable {α : Type u} {β : α → Type v} {γ : α → Type w} {δ : Type w} {m : Type w → Type w} +private local instance : Coe (Type v) (α → Type v) where coe γ := fun _ => γ namespace Std.DTreeMap.Internal @@ -74,6 +75,26 @@ theorem WF.constInsertManyIfNewUnit [Ord α] {ρ} [ForIn Id ρ α] {t : Impl α {h} (hwf : WF t) : WF (Impl.Const.insertManyIfNewUnit t l h).val := (Impl.Const.insertManyIfNewUnit t l h).2 hwf fun _ _ _ hwf' => hwf'.insertIfNew +theorem WF.getThenInsertIfNew? [Ord α] [LawfulEqOrd α] {t : Impl α β} {k v} {h : t.WF} : + (t.getThenInsertIfNew? k v h.balanced).2.WF := by + simp only [Impl.getThenInsertIfNew?] + split + · exact h.insertIfNew + · exact h + +section Const + +variable {β : Type v} + +theorem WF.constGetThenInsertIfNew? [Ord α] {t : Impl α β} {k v} {h : t.WF} : + (Impl.Const.getThenInsertIfNew? k v t h.balanced).2.WF := by + simp only [Impl.Const.getThenInsertIfNew?] + split + · exact h.insertIfNew + · exact h + +end Const + end Impl end Std.DTreeMap.Internal diff --git a/src/Std/Data/DTreeMap/Raw.lean b/src/Std/Data/DTreeMap/Raw.lean index f1ff8c4f67..c0b4e455f2 100644 --- a/src/Std/Data/DTreeMap/Raw.lean +++ b/src/Std/Data/DTreeMap/Raw.lean @@ -126,6 +126,13 @@ def containsThenInsertIfNew (t : Raw α β cmp) (a : α) (b : β a) : let p := t.inner.containsThenInsertIfNew! a b (p.1, ⟨p.2⟩) +@[inline, inherit_doc DTreeMap.getThenInsertIfNew?] +def getThenInsertIfNew? [LawfulEqCmp cmp] (t : Raw α β cmp) (a : α) (b : β a) : + Option (β a) × Raw α β cmp := + letI : Ord α := ⟨cmp⟩ + let p := t.inner.getThenInsertIfNew?! a b + (p.1, ⟨p.2⟩) + @[inline, inherit_doc DTreeMap.contains] def contains (t : Raw α β cmp) (a : α) : Bool := letI : Ord α := ⟨cmp⟩; t.inner.contains a @@ -380,6 +387,12 @@ namespace Const variable {β : Type v} +@[inline, inherit_doc DTreeMap.Const.getThenInsertIfNew?] +def getThenInsertIfNew? (t : Raw α β cmp) (a : α) (b : β) : Option β × Raw α β cmp := + letI : Ord α := ⟨cmp⟩ + let p := Impl.Const.getThenInsertIfNew?! a b t.inner + (p.1, ⟨p.2⟩) + @[inline, inherit_doc DTreeMap.Const.get?] def get? (t : Raw α β cmp) (a : α) : Option β := letI : Ord α := ⟨cmp⟩; Impl.Const.get? a t.inner @@ -541,6 +554,14 @@ def foldr (f : δ → (a : α) → β a → δ) (init : δ) (t : Raw α β cmp) def revFold (f : δ → (a : α) → β a → δ) (init : δ) (t : Raw α β cmp) : δ := foldr f init t +@[inline, inherit_doc DTreeMap.partition] +def partition (f : (a : α) → β a → Bool) (t : Raw α β cmp) : Raw α β cmp × Raw α β cmp := + t.foldl (init := (∅, ∅)) fun ⟨l, r⟩ a b => + if f a b then + (l.insert a b, r) + else + (l, r.insert a b) + @[inline, inherit_doc DTreeMap.forM] def forM (f : (a : α) → β a → m PUnit) (t : Raw α β cmp) : m PUnit := t.inner.forM f diff --git a/src/Std/Data/TreeMap/Basic.lean b/src/Std/Data/TreeMap/Basic.lean index 6c711d3159..0c188fc12a 100644 --- a/src/Std/Data/TreeMap/Basic.lean +++ b/src/Std/Data/TreeMap/Basic.lean @@ -103,6 +103,12 @@ def containsThenInsertIfNew (t : TreeMap α β cmp) (a : α) (b : β) : let p := t.inner.containsThenInsertIfNew a b (p.1, ⟨p.2⟩) +@[inline, inherit_doc DTreeMap.getThenInsertIfNew?] +def getThenInsertIfNew? (t : TreeMap α β cmp) (a : α) (b : β) : Option β × TreeMap α β cmp := + letI : Ord α := ⟨cmp⟩ + let p := DTreeMap.Const.getThenInsertIfNew? t.inner a b + (p.1, ⟨p.2⟩) + @[inline, inherit_doc DTreeMap.contains] def contains (l : TreeMap α β cmp) (a : α) : Bool := l.inner.contains a @@ -394,6 +400,10 @@ def foldr (f : δ → (a : α) → β → δ) (init : δ) (t : TreeMap α β cmp def revFold (f : δ → (a : α) → β → δ) (init : δ) (t : TreeMap α β cmp) : δ := foldr f init t +@[inline, inherit_doc DTreeMap.partition] +def partition (f : (a : α) → β → Bool) (t : TreeMap α β cmp) : TreeMap α β cmp × TreeMap α β cmp := + let p := t.inner.partition f; (⟨p.1⟩, ⟨p.2⟩) + @[inline, inherit_doc DTreeMap.forM] def forM (f : α → β → m PUnit) (t : TreeMap α β cmp) : m PUnit := t.inner.forM f diff --git a/src/Std/Data/TreeMap/Raw.lean b/src/Std/Data/TreeMap/Raw.lean index 51d1177fda..e927c21554 100644 --- a/src/Std/Data/TreeMap/Raw.lean +++ b/src/Std/Data/TreeMap/Raw.lean @@ -107,11 +107,10 @@ instance : LawfulSingleton (α × β) (Raw α β cmp) where @[inline, inherit_doc DTreeMap.Raw.insertIfNew] def insertIfNew (t : Raw α β cmp) (a : α) (b : β) : Raw α β cmp := - letI : Ord α := ⟨cmp⟩; ⟨t.inner.insertIfNew a b⟩ + ⟨t.inner.insertIfNew a b⟩ @[inline, inherit_doc DTreeMap.Raw.containsThenInsert] def containsThenInsert (t : Raw α β cmp) (a : α) (b : β) : Bool × Raw α β cmp := - letI : Ord α := ⟨cmp⟩ let p := t.inner.containsThenInsert a b (p.1, ⟨p.2⟩) @@ -121,6 +120,11 @@ def containsThenInsertIfNew (t : Raw α β cmp) (a : α) (b : β) : let p := t.inner.containsThenInsertIfNew a b (p.1, ⟨p.2⟩) +@[inline, inherit_doc DTreeMap.Raw.getThenInsertIfNew?] +def getThenInsertIfNew? (t : Raw α β cmp) (a : α) (b : β) : Option β × Raw α β cmp := + let p := DTreeMap.Raw.Const.getThenInsertIfNew? t.inner a b + (p.1, ⟨p.2⟩) + @[inline, inherit_doc DTreeMap.Raw.contains] def contains (l : Raw α β cmp) (a : α) : Bool := l.inner.contains a @@ -402,6 +406,10 @@ def foldr (f : δ → (a : α) → β → δ) (init : δ) (t : Raw α β cmp) : def revFold (f : δ → (a : α) → β → δ) (init : δ) (t : Raw α β cmp) : δ := foldr f init t +@[inline, inherit_doc DTreeMap.Raw.partition] +def partition (f : (a : α) → β → Bool) (t : Raw α β cmp) : Raw α β cmp × Raw α β cmp := + let p := t.inner.partition f; (⟨p.1⟩, ⟨p.2⟩) + @[inline, inherit_doc DTreeMap.Raw.forM] def forM (f : α → β → m PUnit) (t : Raw α β cmp) : m PUnit := t.inner.forM f diff --git a/src/Std/Data/TreeSet/Basic.lean b/src/Std/Data/TreeSet/Basic.lean index 6bc9fd0a79..b15c5364b3 100644 --- a/src/Std/Data/TreeSet/Basic.lean +++ b/src/Std/Data/TreeSet/Basic.lean @@ -368,6 +368,11 @@ def foldr (f : δ → (a : α) → δ) (init : δ) (t : TreeSet α cmp) : δ := def revFold (f : δ → (a : α) → δ) (init : δ) (t : TreeSet α cmp) : δ := foldr f init t +/-- Partitions a tree set into two tree sets based on a predicate. -/ +@[inline] +def partition (f : (a : α) → Bool) (t : TreeSet α cmp) : TreeSet α cmp × TreeSet α cmp := + let p := t.inner.partition fun a _ => f a; (⟨p.1⟩, ⟨p.2⟩) + /-- Carries out a monadic action on each element in the tree set in ascending order. -/ @[inline] def forM (f : α → m PUnit) (t : TreeSet α cmp) : m PUnit := diff --git a/src/Std/Data/TreeSet/Raw.lean b/src/Std/Data/TreeSet/Raw.lean index d0a073f480..b46a7b6f29 100644 --- a/src/Std/Data/TreeSet/Raw.lean +++ b/src/Std/Data/TreeSet/Raw.lean @@ -267,6 +267,10 @@ def foldr (f : δ → (a : α) → δ) (init : δ) (t : Raw α cmp) : δ := def revFold (f : δ → (a : α) → δ) (init : δ) (t : Raw α cmp) : δ := foldr f init t +@[inline, inherit_doc TreeSet.partition] +def partition (f : (a : α) → Bool) (t : Raw α cmp) : Raw α cmp × Raw α cmp := + let p := t.inner.partition fun a _ => f a; (⟨p.1⟩, ⟨p.2⟩) + @[inline, inherit_doc TreeSet.empty] def forM (f : α → m PUnit) (t : Raw α cmp) : m PUnit := t.inner.forM (fun a _ => f a)