feat: getThenInsertIfNew? and partition functions for the tree map (#7109)

This PR implements the `getThenInsertIfNew?` and `partition` functions
on the tree map.

---------

Co-authored-by: Paul Reichert <6992158+datokrat@users.noreply.github.com>
This commit is contained in:
Paul Reichert 2025-02-18 09:29:24 +01:00 committed by GitHub
parent 4e10e4e02e
commit d9e7ded5af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 144 additions and 3 deletions

View file

@ -196,7 +196,7 @@ section Unverified
(init : δ) (b : DHashMap α β) : δ :=
b.1.fold f init
/-- Partition a hashset into two hashsets based on a predicate. -/
/-- Partition a hash map into two hash map based on a predicate. -/
@[inline] def partition (f : (a : α) → β a → Bool)
(m : DHashMap α β) : DHashMap α β × DHashMap α β :=
m.fold (init := (∅, ∅)) fun ⟨l, r⟩ a b =>

View file

@ -137,6 +137,24 @@ def containsThenInsertIfNew (t : DTreeMap α β cmp) (a : α) (b : β a) :
let p := t.inner.containsThenInsertIfNew a b t.wf.balanced
(p.1, ⟨p.2.impl, t.wf.containsThenInsertIfNew⟩)
/--
Checks whether a key is present in a map, returning the associated value, and inserts a value for
the key if it was not found.
If the returned value is `some v`, then the returned map is unaltered. If it is `none`, then the
returned map has a new value inserted.
Equivalent to (but potentially faster than) calling `get?` followed by `insertIfNew`.
Uses the `LawfulEqCmp` instance to cast the retrieved value to the correct type.
-/
@[inline]
def getThenInsertIfNew? [LawfulEqCmp cmp] (t : DTreeMap α β cmp) (a : α) (b : β a) :
Option (β a) × DTreeMap α β cmp :=
letI : Ord α := ⟨cmp⟩
let p := t.inner.getThenInsertIfNew? a b t.wf.balanced
(p.1, ⟨p.2, t.wf.getThenInsertIfNew?⟩)
/--
Returns `true` if there is a mapping for the given key `a` or a key that is equal to `a` according
to the comparator `cmp`. There is also a `Prop`-valued version
@ -575,6 +593,13 @@ namespace Const
variable {β : Type v}
@[inline, inherit_doc DTreeMap.getThenInsertIfNew?]
def getThenInsertIfNew? (t : DTreeMap α β cmp) (a : α) (b : β) :
Option β × DTreeMap α β cmp :=
letI : Ord α := ⟨cmp⟩
let p := Impl.Const.getThenInsertIfNew? a b t.inner t.wf.balanced
(p.1, ⟨p.2, t.wf.constGetThenInsertIfNew?⟩)
@[inline, inherit_doc DTreeMap.get?]
def get? (t : DTreeMap α β cmp) (a : α) : Option β :=
letI : Ord α := ⟨cmp⟩; Impl.Const.get? a t.inner
@ -746,6 +771,15 @@ def foldr (f : δ → (a : α) → β a → δ) (init : δ) (t : DTreeMap α β
def revFold (f : δ → (a : α) → β a → δ) (init : δ) (t : DTreeMap α β cmp) : δ :=
foldr f init t
/-- Partitions a tree map into two tree maps based on a predicate. -/
@[inline] def partition (f : (a : α) → β a → Bool)
(t : DTreeMap α β cmp) : DTreeMap α β cmp × DTreeMap α β cmp :=
t.foldl (init := (∅, ∅)) fun ⟨l, r⟩ a b =>
if f a b then
(l.insert a b, r)
else
(l, r.insert a b)
/-- Carries out a monadic action on each mapping in the tree map in ascending order. -/
@[inline]
def forM (f : (a : α) → β a → m PUnit) (t : DTreeMap α β cmp) : m PUnit :=

View file

@ -399,6 +399,25 @@ def containsThenInsertIfNew! [Ord α] (k : α) (v : β k) (t : Impl α β) :
Bool × Impl α β :=
if t.contains k then (true, t) else (false, t.insert! k v)
/-- Implementation detail of the tree map -/
@[inline]
def getThenInsertIfNew? [Ord α] [LawfulEqOrd α] (k : α) (v : β k) (t : Impl α β) (ht : t.Balanced) :
Option (β k) × Impl α β :=
match t.get? k with
| none => (none, t.insertIfNew k v ht |>.impl)
| some b => (some b, t)
/--
Slower version of `getThenInsertIfNew?` which can be used in the absence of balance
information but still assumes the preconditions of `getThenInsertIfNew?`, otherwise might panic.
-/
@[inline]
def getThenInsertIfNew?! [Ord α] [LawfulEqOrd α] (k : α) (v : β k) (t : Impl α β) :
Option (β k) × Impl α β :=
match t.get? k with
| none => (none, t.insertIfNew! k v)
| some b => (some b, t)
/-- Removes the mapping with key `k`, if it exists. -/
def erase [Ord α] (k : α) (t : Impl α β) (h : t.Balanced) :
SizedBalancedTree α β (t.size - 1) t.size :=
@ -583,6 +602,25 @@ namespace Const
variable {β : Type v}
/-- Implementation detail of the tree map -/
@[inline]
def getThenInsertIfNew? [Ord α] (k : α) (v : β) (t : Impl α (fun _ => β))
(ht : t.Balanced) : Option β × Impl α (fun _ => β) :=
match get? k t with
| none => (none, t.insertIfNew k v ht |>.impl)
| some b => (some b, t)
/--
Slower version of `getThenInsertIfNew?` which can be used in the absence of balance
information but still assumes the preconditions of `getThenInsertIfNew?`, otherwise might panic.
-/
@[inline]
def getThenInsertIfNew?! [Ord α] (k : α) (v : β) (t : Impl α (fun _ => β))
: Option β × Impl α (fun _ => β) :=
match get? k t with
| none => (none, t.insertIfNew! k v)
| some b => (some b, t)
/-- Transforms a list of mappings into a tree map. -/
@[inline] def ofArray [Ord α] (a : Array (α × β)) : Impl α (fun _ => β) :=
insertMany empty a balanced_empty |>.val

View file

@ -21,6 +21,7 @@ set_option linter.all true
universe u v w
variable {α : Type u} {β : α → Type v} {γ : α → Type w} {δ : Type w} {m : Type w → Type w}
private local instance : Coe (Type v) (α → Type v) where coe γ := fun _ => γ
namespace Std.DTreeMap.Internal
@ -74,6 +75,26 @@ theorem WF.constInsertManyIfNewUnit [Ord α] {ρ} [ForIn Id ρ α] {t : Impl α
{h} (hwf : WF t) : WF (Impl.Const.insertManyIfNewUnit t l h).val :=
(Impl.Const.insertManyIfNewUnit t l h).2 hwf fun _ _ _ hwf' => hwf'.insertIfNew
theorem WF.getThenInsertIfNew? [Ord α] [LawfulEqOrd α] {t : Impl α β} {k v} {h : t.WF} :
(t.getThenInsertIfNew? k v h.balanced).2.WF := by
simp only [Impl.getThenInsertIfNew?]
split
· exact h.insertIfNew
· exact h
section Const
variable {β : Type v}
theorem WF.constGetThenInsertIfNew? [Ord α] {t : Impl α β} {k v} {h : t.WF} :
(Impl.Const.getThenInsertIfNew? k v t h.balanced).2.WF := by
simp only [Impl.Const.getThenInsertIfNew?]
split
· exact h.insertIfNew
· exact h
end Const
end Impl
end Std.DTreeMap.Internal

View file

@ -126,6 +126,13 @@ def containsThenInsertIfNew (t : Raw α β cmp) (a : α) (b : β a) :
let p := t.inner.containsThenInsertIfNew! a b
(p.1, ⟨p.2⟩)
@[inline, inherit_doc DTreeMap.getThenInsertIfNew?]
def getThenInsertIfNew? [LawfulEqCmp cmp] (t : Raw α β cmp) (a : α) (b : β a) :
Option (β a) × Raw α β cmp :=
letI : Ord α := ⟨cmp⟩
let p := t.inner.getThenInsertIfNew?! a b
(p.1, ⟨p.2⟩)
@[inline, inherit_doc DTreeMap.contains]
def contains (t : Raw α β cmp) (a : α) : Bool :=
letI : Ord α := ⟨cmp⟩; t.inner.contains a
@ -380,6 +387,12 @@ namespace Const
variable {β : Type v}
@[inline, inherit_doc DTreeMap.Const.getThenInsertIfNew?]
def getThenInsertIfNew? (t : Raw α β cmp) (a : α) (b : β) : Option β × Raw α β cmp :=
letI : Ord α := ⟨cmp⟩
let p := Impl.Const.getThenInsertIfNew?! a b t.inner
(p.1, ⟨p.2⟩)
@[inline, inherit_doc DTreeMap.Const.get?]
def get? (t : Raw α β cmp) (a : α) : Option β :=
letI : Ord α := ⟨cmp⟩; Impl.Const.get? a t.inner
@ -541,6 +554,14 @@ def foldr (f : δ → (a : α) → β a → δ) (init : δ) (t : Raw α β cmp)
def revFold (f : δ → (a : α) → β a → δ) (init : δ) (t : Raw α β cmp) : δ :=
foldr f init t
@[inline, inherit_doc DTreeMap.partition]
def partition (f : (a : α) → β a → Bool) (t : Raw α β cmp) : Raw α β cmp × Raw α β cmp :=
t.foldl (init := (∅, ∅)) fun ⟨l, r⟩ a b =>
if f a b then
(l.insert a b, r)
else
(l, r.insert a b)
@[inline, inherit_doc DTreeMap.forM]
def forM (f : (a : α) → β a → m PUnit) (t : Raw α β cmp) : m PUnit :=
t.inner.forM f

View file

@ -103,6 +103,12 @@ def containsThenInsertIfNew (t : TreeMap α β cmp) (a : α) (b : β) :
let p := t.inner.containsThenInsertIfNew a b
(p.1, ⟨p.2⟩)
@[inline, inherit_doc DTreeMap.getThenInsertIfNew?]
def getThenInsertIfNew? (t : TreeMap α β cmp) (a : α) (b : β) : Option β × TreeMap α β cmp :=
letI : Ord α := ⟨cmp⟩
let p := DTreeMap.Const.getThenInsertIfNew? t.inner a b
(p.1, ⟨p.2⟩)
@[inline, inherit_doc DTreeMap.contains]
def contains (l : TreeMap α β cmp) (a : α) : Bool :=
l.inner.contains a
@ -394,6 +400,10 @@ def foldr (f : δ → (a : α) → β → δ) (init : δ) (t : TreeMap α β cmp
def revFold (f : δ → (a : α) → β → δ) (init : δ) (t : TreeMap α β cmp) : δ :=
foldr f init t
@[inline, inherit_doc DTreeMap.partition]
def partition (f : (a : α) → β → Bool) (t : TreeMap α β cmp) : TreeMap α β cmp × TreeMap α β cmp :=
let p := t.inner.partition f; (⟨p.1⟩, ⟨p.2⟩)
@[inline, inherit_doc DTreeMap.forM]
def forM (f : α → β → m PUnit) (t : TreeMap α β cmp) : m PUnit :=
t.inner.forM f

View file

@ -107,11 +107,10 @@ instance : LawfulSingleton (α × β) (Raw α β cmp) where
@[inline, inherit_doc DTreeMap.Raw.insertIfNew]
def insertIfNew (t : Raw α β cmp) (a : α) (b : β) : Raw α β cmp :=
letI : Ord α := ⟨cmp⟩; ⟨t.inner.insertIfNew a b⟩
⟨t.inner.insertIfNew a b⟩
@[inline, inherit_doc DTreeMap.Raw.containsThenInsert]
def containsThenInsert (t : Raw α β cmp) (a : α) (b : β) : Bool × Raw α β cmp :=
letI : Ord α := ⟨cmp⟩
let p := t.inner.containsThenInsert a b
(p.1, ⟨p.2⟩)
@ -121,6 +120,11 @@ def containsThenInsertIfNew (t : Raw α β cmp) (a : α) (b : β) :
let p := t.inner.containsThenInsertIfNew a b
(p.1, ⟨p.2⟩)
@[inline, inherit_doc DTreeMap.Raw.getThenInsertIfNew?]
def getThenInsertIfNew? (t : Raw α β cmp) (a : α) (b : β) : Option β × Raw α β cmp :=
let p := DTreeMap.Raw.Const.getThenInsertIfNew? t.inner a b
(p.1, ⟨p.2⟩)
@[inline, inherit_doc DTreeMap.Raw.contains]
def contains (l : Raw α β cmp) (a : α) : Bool :=
l.inner.contains a
@ -402,6 +406,10 @@ def foldr (f : δ → (a : α) → β → δ) (init : δ) (t : Raw α β cmp) :
def revFold (f : δ → (a : α) → β → δ) (init : δ) (t : Raw α β cmp) : δ :=
foldr f init t
@[inline, inherit_doc DTreeMap.Raw.partition]
def partition (f : (a : α) → β → Bool) (t : Raw α β cmp) : Raw α β cmp × Raw α β cmp :=
let p := t.inner.partition f; (⟨p.1⟩, ⟨p.2⟩)
@[inline, inherit_doc DTreeMap.Raw.forM]
def forM (f : α → β → m PUnit) (t : Raw α β cmp) : m PUnit :=
t.inner.forM f

View file

@ -368,6 +368,11 @@ def foldr (f : δ → (a : α) → δ) (init : δ) (t : TreeSet α cmp) : δ :=
def revFold (f : δ → (a : α) → δ) (init : δ) (t : TreeSet α cmp) : δ :=
foldr f init t
/-- Partitions a tree set into two tree sets based on a predicate. -/
@[inline]
def partition (f : (a : α) → Bool) (t : TreeSet α cmp) : TreeSet α cmp × TreeSet α cmp :=
let p := t.inner.partition fun a _ => f a; (⟨p.1⟩, ⟨p.2⟩)
/-- Carries out a monadic action on each element in the tree set in ascending order. -/
@[inline]
def forM (f : α → m PUnit) (t : TreeSet α cmp) : m PUnit :=

View file

@ -267,6 +267,10 @@ def foldr (f : δ → (a : α) → δ) (init : δ) (t : Raw α cmp) : δ :=
def revFold (f : δ → (a : α) → δ) (init : δ) (t : Raw α cmp) : δ :=
foldr f init t
@[inline, inherit_doc TreeSet.partition]
def partition (f : (a : α) → Bool) (t : Raw α cmp) : Raw α cmp × Raw α cmp :=
let p := t.inner.partition fun a _ => f a; (⟨p.1⟩, ⟨p.2⟩)
@[inline, inherit_doc TreeSet.empty]
def forM (f : α → m PUnit) (t : Raw α cmp) : m PUnit :=
t.inner.forM (fun a _ => f a)