lean4-htt/src/Init/Data/Array/BinSearch.lean
Kim Morrison 884a9ea2ff
feat: remove partial keyword and runtime bounds checks from Array.binSearch (#6193)
This PR completes the TODO in `Init.Data.Array.BinSearch`, removing the
`partial` keyword and converting runtime bounds checks to compile time
bounds checks.
2024-11-24 06:08:16 +00:00

83 lines
3.3 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Data.Array.Basic
import Init.Omega
universe u v
namespace Array
@[specialize] def binSearchAux {α : Type u} {β : Type v} (lt : αα → Bool) (found : Option α → β) (as : Array α) (k : α) :
(lo : Fin (as.size + 1)) → (hi : Fin as.size) → (lo.1 ≤ hi.1) → β
| lo, hi, h =>
let m := (lo.1 + hi.1)/2
let a := as[m]
if lt a k then
if h' : m + 1 ≤ hi.1 then
binSearchAux lt found as k ⟨m+1, by omega⟩ hi h'
else found none
else if lt k a then
if h' : m = 0 m - 1 < lo.1 then found none
else binSearchAux lt found as k lo ⟨m-1, by omega⟩ (by simp; omega)
else found (some a)
termination_by lo hi => hi.1 - lo.1
@[inline] def binSearch {α : Type} (as : Array α) (k : α) (lt : αα → Bool) (lo := 0) (hi := as.size - 1) : Option α :=
if h : lo < as.size then
let hi := if hi < as.size then hi else as.size - 1
if w : lo ≤ hi then
binSearchAux lt id as k ⟨lo, by omega⟩ ⟨hi, by simp [hi]; split <;> omega⟩ (by simp [hi]; omega)
else
none
else
none
@[inline] def binSearchContains {α : Type} (as : Array α) (k : α) (lt : αα → Bool) (lo := 0) (hi := as.size - 1) : Bool :=
if h : lo < as.size then
let hi := if hi < as.size then hi else as.size - 1
if w : lo ≤ hi then
binSearchAux lt Option.isSome as k ⟨lo, by omega⟩ ⟨hi, by simp [hi]; split <;> omega⟩ (by simp [hi]; omega)
else
false
else
false
@[specialize] private def binInsertAux {α : Type u} {m : Type u → Type v} [Monad m]
(lt : αα → Bool)
(merge : α → m α)
(add : Unit → m α)
(as : Array α)
(k : α) : (lo : Fin as.size) → (hi : Fin as.size) → (lo.1 ≤ hi.1) → (lt as[lo] k) → m (Array α)
| lo, hi, h, w =>
let mid := (lo.1 + hi.1)/2
let midVal := as[mid]
if w₁ : lt midVal k then
if h' : mid = lo then do let v ← add (); pure <| as.insertIdx (lo+1) v
else binInsertAux lt merge add as k ⟨mid, by omega⟩ hi (by simp; omega) w₁
else if w₂ : lt k midVal then
have : mid ≠ lo := fun z => by simp [midVal, z] at w₁; simp_all
binInsertAux lt merge add as k lo ⟨mid, by omega⟩ (by simp; omega) w
else do
as.modifyM mid <| fun v => merge v
termination_by lo hi => hi.1 - lo.1
@[specialize] def binInsertM {α : Type u} {m : Type u → Type v} [Monad m]
(lt : αα → Bool)
(merge : α → m α)
(add : Unit → m α)
(as : Array α)
(k : α) : m (Array α) :=
if h : as.size = 0 then do let v ← add (); pure <| as.push v
else if lt k as[0] then do let v ← add (); pure <| as.insertIdx 0 v
else if h' : !lt as[0] k then as.modifyM 0 <| merge
else if lt as[as.size - 1] k then do let v ← add (); pure <| as.push v
else if !lt k as[as.size - 1] then as.modifyM (as.size - 1) <| merge
else binInsertAux lt merge add as k ⟨0, by omega⟩ ⟨as.size - 1, by omega⟩ (by simp) (by simpa using h')
@[inline] def binInsert {α : Type u} (lt : αα → Bool) (as : Array α) (k : α) : Array α :=
Id.run <| binInsertM lt (fun _ => k) (fun _ => k) as k
end Array