feat: remove runtime bounds checks and partial from qsort (#6241)
This PR refactors `Array.qsort` to remove runtime array bounds checks, and avoids the use of `partial`. We use the `Vector` API, along with auto_params, to avoid having to write any proofs. The new code benchmarks indistinguishably from the old.
This commit is contained in:
parent
7b8504cf06
commit
3ee2842e77
6 changed files with 52 additions and 27 deletions
|
|
@ -4,46 +4,46 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Array.Basic
|
||||
import Init.Data.Vector.Basic
|
||||
import Init.Data.Ord
|
||||
|
||||
namespace Array
|
||||
-- TODO: remove the [Inhabited α] parameters as soon as we have the tactic framework for automating proof generation and using Array.fget
|
||||
|
||||
def qpartition (as : Array α) (lt : α → α → Bool) (lo hi : Nat) : Nat × Array α :=
|
||||
if h : as.size = 0 then (0, as) else have : Inhabited α := ⟨as[0]'(by revert h; cases as.size <;> simp)⟩ -- TODO: remove
|
||||
private def qpartition {n} (as : Vector α n) (lt : α → α → Bool) (lo hi : Nat)
|
||||
(hlo : lo < n := by omega) (hhi : hi < n := by omega) : {n : Nat // lo ≤ n} × Vector α n :=
|
||||
let mid := (lo + hi) / 2
|
||||
let as := if lt (as.get! mid) (as.get! lo) then as.swapIfInBounds lo mid else as
|
||||
let as := if lt (as.get! hi) (as.get! lo) then as.swapIfInBounds lo hi else as
|
||||
let as := if lt (as.get! mid) (as.get! hi) then as.swapIfInBounds mid hi else as
|
||||
let pivot := as.get! hi
|
||||
let rec loop (as : Array α) (i j : Nat) :=
|
||||
let as := if lt as[mid] as[lo] then as.swap lo mid else as
|
||||
let as := if lt as[hi] as[lo] then as.swap lo hi else as
|
||||
let as := if lt as[mid] as[hi] then as.swap mid hi else as
|
||||
let pivot := as[hi]
|
||||
let rec loop (as : Vector α n) (i j : Nat)
|
||||
(ilo : lo ≤ i := by omega) (jh : j < n := by omega) (w : i ≤ j := by omega) :=
|
||||
if h : j < hi then
|
||||
if lt (as.get! j) pivot then
|
||||
let as := as.swapIfInBounds i j
|
||||
loop as (i+1) (j+1)
|
||||
if lt as[j] pivot then
|
||||
loop (as.swap i j) (i+1) (j+1)
|
||||
else
|
||||
loop as i (j+1)
|
||||
else
|
||||
let as := as.swapIfInBounds i hi
|
||||
(i, as)
|
||||
termination_by hi - j
|
||||
decreasing_by all_goals simp_wf; decreasing_trivial_pre_omega
|
||||
(⟨i, ilo⟩, as.swap i hi)
|
||||
loop as lo lo
|
||||
|
||||
@[inline] partial def qsort (as : Array α) (lt : α → α → Bool) (low := 0) (high := as.size - 1) : Array α :=
|
||||
let rec @[specialize] sort (as : Array α) (low high : Nat) :=
|
||||
if low < high then
|
||||
let p := qpartition as lt low high;
|
||||
-- TODO: fix `partial` support in the equation compiler, it breaks if we use `let (mid, as) := partition as lt low high`
|
||||
let mid := p.1
|
||||
let as := p.2
|
||||
if mid >= high then as
|
||||
@[inline] def qsort (as : Array α) (lt : α → α → Bool := by exact (· < ·))
|
||||
(low := 0) (high := as.size - 1) : Array α :=
|
||||
let rec @[specialize] sort {n} (as : Vector α n) (lo hi : Nat)
|
||||
(hlo : lo < n := by omega) (hhi : hi < n := by omega) :=
|
||||
if h₁ : lo < hi then
|
||||
let ⟨⟨mid, hmid⟩, as⟩ := qpartition as lt lo hi
|
||||
if h₂ : mid ≥ hi then
|
||||
as
|
||||
else
|
||||
let as := sort as low mid
|
||||
sort as (mid+1) high
|
||||
sort (sort as lo mid) (mid+1) hi
|
||||
else as
|
||||
sort as low high
|
||||
if h : as.size = 0 then
|
||||
as
|
||||
else
|
||||
let low := min low (as.size - 1)
|
||||
let high := min high (as.size - 1)
|
||||
sort ⟨as, rfl⟩ low high |>.toArray
|
||||
|
||||
set_option linter.unusedVariables.funArgs false in
|
||||
/--
|
||||
|
|
|
|||
1
tests/bench/qsort/.gitignore
vendored
Normal file
1
tests/bench/qsort/.gitignore
vendored
Normal file
|
|
@ -0,0 +1 @@
|
|||
/.lake
|
||||
15
tests/bench/qsort/Main.lean
Normal file
15
tests/bench/qsort/Main.lean
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
set_option linter.unusedVariables false
|
||||
|
||||
abbrev Elem := UInt32
|
||||
|
||||
def badRand (seed : Elem) : Elem :=
|
||||
seed * 1664525 + 1013904223
|
||||
|
||||
def mkRandomArray : Nat → Elem → Array Elem → Array Elem
|
||||
| 0, seed, as => as
|
||||
| i+1, seed, as => mkRandomArray i (badRand seed) (as.push seed)
|
||||
|
||||
def main (args : List String) : IO UInt32 := do
|
||||
let a := mkRandomArray 4000000 0 (Array.mkEmpty 4000000)
|
||||
IO.println (a.qsort (· < ·)).size
|
||||
return 0
|
||||
1
tests/bench/qsort/README.md
Normal file
1
tests/bench/qsort/README.md
Normal file
|
|
@ -0,0 +1 @@
|
|||
# insertionSort
|
||||
7
tests/bench/qsort/lakefile.toml
Normal file
7
tests/bench/qsort/lakefile.toml
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
name = "qsort"
|
||||
version = "0.1.0"
|
||||
defaultTargets = ["qsort"]
|
||||
|
||||
[[lean_exe]]
|
||||
name = "qsort"
|
||||
root = "Main"
|
||||
1
tests/bench/qsort/lean-toolchain
Normal file
1
tests/bench/qsort/lean-toolchain
Normal file
|
|
@ -0,0 +1 @@
|
|||
lean4
|
||||
Loading…
Add table
Reference in a new issue