lean4-htt/library/init/data/rbmap/basic.lean
Leonardo de Moura 778e7e41f9 refactor(library/init/data/rbmap/basic): pass ins node-cell to balance1 and balance2.
The idea is to reuse the cell. The trick is like the one we used for
improving state_t. It seems to work pretty well. Now, the Lean
version is 29% slower than the C++ one.

cc @kha
2019-02-20 18:27:58 -08:00

226 lines
8.2 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2017 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import init.data.ordering.basic init.coe init.data.option.basic
universes u v w w'
inductive rbcolor
| red | black
inductive rbnode (α : Type u) (β : α → Type v)
| leaf {} : rbnode
| node (color : rbcolor) (lchild : rbnode) (key : α) (val : β key) (rchild : rbnode) : rbnode
namespace rbnode
variables {α : Type u} {β : α → Type v} {σ : Type w}
open rbcolor nat
def depth (f : nat → nat → nat) : rbnode α β → nat
| leaf := 0
| (node _ l _ _ r) := succ (f (depth l) (depth r))
protected def min : rbnode α β → option (Σ k : α, β k)
| leaf := none
| (node _ leaf k v _) := some ⟨k, v⟩
| (node _ l k v _) := min l
protected def max : rbnode α β → option (Σ k : α, β k)
| leaf := none
| (node _ _ k v leaf) := some ⟨k, v⟩
| (node _ _ k v r) := max r
@[specialize] def fold (f : Π (k : α), β k → σσ) : rbnode α β → σσ
| leaf b := b
| (node _ l k v r) b := fold r (f k v (fold l b))
@[specialize] def mfold {m : Type w → Type w'} [monad m] (f : Π (k : α), β k → σ → m σ) : rbnode α β → σ → m σ
| leaf b := pure b
| (node _ l k v r) b := do b₁ ← mfold l b, b₂ ← f k v b₁, mfold r b₂
@[specialize] def rev_fold (f : Π (k : α), β k → σσ) : rbnode α β → σσ
| leaf b := b
| (node _ l k v r) b := rev_fold l (f k v (rev_fold r b))
@[specialize] def all (p : Π k : α, β k → bool) : rbnode α β → bool
| leaf := tt
| (node _ l k v r) := p k v && all l && all r
@[specialize] def any (p : Π k : α, β k → bool) : rbnode α β → bool
| leaf := ff
| (node _ l k v r) := p k v || any l || any r
def balance1 : rbnode α β → rbnode α β → rbnode α β
| (node _ _ kv vv t) (node _ (node red l kx vx r₁) ky vy r₂) := node red (node black l kx vx r₁) ky vy (node black r₂ kv vv t)
| (node _ _ kv vv t) (node _ l₁ ky vy (node red l₂ kx vx r)) := node red (node black l₁ ky vy l₂) kx vx (node black r kv vv t)
| (node _ _ kv vv t) (node _ l ky vy r) := node black (node red l ky vy r) kv vv t
| _ _ := leaf -- unreachable
def balance2 : rbnode α β → rbnode α β → rbnode α β
| (node _ t kv vv _) (node _ (node red l kx₁ vx₁ r₁) ky vy r₂) := node red (node black t kv vv l) kx₁ vx₁ (node black r₁ ky vy r₂)
| (node _ t kv vv _) (node _ l₁ ky vy (node red l₂ kx₂ vx₂ r₂)) := node red (node black t kv vv l₁) ky vy (node black l₂ kx₂ vx₂ r₂)
| (node _ t kv vv _) (node _ l ky vy r) := node black t kv vv (node red l ky vy r)
| _ _ := leaf -- unreachable
def is_red : rbnode α β → bool
| (node red _ _ _ _) := tt
| _ := ff
section insert
variables (lt : αα → Prop) [decidable_rel lt]
def ins : rbnode α β → Π k : α, β k → rbnode α β
| leaf kx vx := node red leaf kx vx leaf
| (node red a ky vy b) kx vx :=
(match cmp_using lt kx ky with
| ordering.lt := node red (ins a kx vx) ky vy b
| ordering.eq := node red a kx vx b
| ordering.gt := node red a ky vy (ins b kx vx))
| (node black a ky vy b) kx vx :=
match cmp_using lt kx ky with
| ordering.lt :=
if is_red a then balance1 (node black leaf ky vy b) (ins a kx vx)
else node black (ins a kx vx) ky vy b
| ordering.eq := node black a kx vx b
| ordering.gt :=
if is_red b then balance2 (node black a ky vy leaf) (ins b kx vx)
else node black a ky vy (ins b kx vx)
def set_black : rbnode α β → rbnode α β
| (node _ l k v r) := node black l k v r
| e := e
def insert (t : rbnode α β) (k : α) (v : β k) : rbnode α β :=
if is_red t then set_black (ins lt t k v)
else ins lt t k v
end insert
section membership
variable (lt : αα → Prop)
variable [decidable_rel lt]
def find_core : rbnode α β → Π k : α, option (Σ k : α, β k)
| leaf x := none
| (node _ a ky vy b) x :=
(match cmp_using lt x ky with
| ordering.lt := find_core a x
| ordering.eq := some ⟨ky, vy⟩
| ordering.gt := find_core b x)
def find {β : Type v} : rbnode α (λ _, β) → α → option β
| leaf x := none
| (node _ a ky vy b) x :=
(match cmp_using lt x ky with
| ordering.lt := find a x
| ordering.eq := some vy
| ordering.gt := find b x)
def lower_bound : rbnode α β → α → option (sigma β) → option (sigma β)
| leaf x lb := lb
| (node _ a ky vy b) x lb :=
(match cmp_using lt x ky with
| ordering.lt := lower_bound a x lb
| ordering.eq := some ⟨ky, vy⟩
| ordering.gt := lower_bound b x (some ⟨ky, vy⟩))
end membership
inductive well_formed (lt : αα → Prop) : rbnode α β → Prop
| leaf_wff : well_formed leaf
| insert_wff {n n' : rbnode α β} {k : α} {v : β k} [decidable_rel lt] : well_formed n → n' = insert lt n k v → well_formed n'
end rbnode
open rbnode
/- TODO(Leo): define d_rbmap -/
def rbmap (α : Type u) (β : Type v) (lt : αα → Prop) : Type (max u v) :=
{t : rbnode α (λ _, β) // t.well_formed lt }
@[inline] def mk_rbmap (α : Type u) (β : Type v) (lt : αα → Prop) : rbmap α β lt :=
⟨leaf, well_formed.leaf_wff lt⟩
namespace rbmap
variables {α : Type u} {β : Type v} {σ : Type w} {lt : αα → Prop}
def depth (f : nat → nat → nat) (t : rbmap α β lt) : nat :=
t.val.depth f
@[inline] def fold (f : α → β → σσ) : rbmap α β lt → σσ
| ⟨t, _⟩ b := t.fold f b
@[inline] def rev_fold (f : α → β → σσ) : rbmap α β lt → σσ
| ⟨t, _⟩ b := t.rev_fold f b
@[inline] def mfold {m : Type w → Type w'} [monad m] (f : α → β → σ → m σ) : rbmap α β lt → σ → m σ
| ⟨t, _⟩ b := t.mfold f b
@[inline] def mfor {m : Type w → Type w'} [monad m] (f : α → β → m σ) (t : rbmap α β lt) : m punit :=
t.mfold (λ k v _, f k v *> pure ⟨⟩) ⟨⟩
@[inline] def empty : rbmap α β lt → bool
| ⟨leaf, _⟩ := tt
| _ := ff
@[specialize] def to_list : rbmap α β lt → list (α × β)
| ⟨t, _⟩ := t.rev_fold (λ k v ps, (k, v)::ps) []
@[inline] protected def min : rbmap α β lt → option (α × β)
| ⟨t, _⟩ :=
match t.min with
| some ⟨k, v⟩ := some (k, v)
| none := none
@[inline] protected def max : rbmap α β lt → option (α × β)
| ⟨t, _⟩ :=
match t.max with
| some ⟨k, v⟩ := some (k, v)
| none := none
instance [has_repr α] [has_repr β] : has_repr (rbmap α β lt) :=
⟨λ t, "rbmap_of " ++ repr t.to_list⟩
variables [decidable_rel lt]
def insert : rbmap α β lt → α → β → rbmap α β lt
| ⟨t, w⟩ k v := ⟨t.insert lt k v, well_formed.insert_wff w rfl⟩
@[specialize] def of_list : list (α × β) → rbmap α β lt
| [] := mk_rbmap _ _ _
| (⟨k,v⟩::xs) := (of_list xs).insert k v
def find_core : rbmap α β lt → α → option (Σ k : α, β)
| ⟨t, _⟩ x := t.find_core lt x
def find : rbmap α β lt → α → option β
| ⟨t, _⟩ x := t.find lt x
/-- (lower_bound k) retrieves the kv pair of the largest key smaller than or equal to `k`,
if it exists. -/
def lower_bound : rbmap α β lt → α → option (Σ k : α, β)
| ⟨t, _⟩ x := t.lower_bound lt x none
@[inline] def contains (t : rbmap α β lt) (a : α) : bool :=
(t.find a).is_some
def from_list (l : list (α × β)) (lt : αα → Prop) [decidable_rel lt] : rbmap α β lt :=
l.foldl (λ r p, r.insert p.1 p.2) (mk_rbmap α β lt)
@[inline] def all : rbmap α β lt → (α → β → bool) → bool
| ⟨t, _⟩ p := t.all p
@[inline] def any : rbmap α β lt → (α → β → bool) → bool
| ⟨t, _⟩ p := t.any p
end rbmap
def rbmap_of {α : Type u} {β : Type v} (l : list (α × β)) (lt : αα → Prop) [decidable_rel lt] : rbmap α β lt :=
rbmap.from_list l lt