fix: make sure splitTarget? skips match expressions that produce type errors at splitMatch
We can now generate the equation theorem for ``` attribute [simp] Array.heapSort.loop ``` see #998
This commit is contained in:
parent
e574c5373f
commit
7fc12014da
2 changed files with 194 additions and 11 deletions
|
|
@ -119,19 +119,21 @@ def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := do
|
|||
throwNestedTacticEx `splitMatch ex
|
||||
|
||||
/-- Return an `if-then-else` or `match-expr` to split. -/
|
||||
partial def findSplit? (env : Environment) (e : Expr) : Option Expr :=
|
||||
partial def findSplit? (env : Environment) (e : Expr) (exceptionSet : ExprSet := {}) : Option Expr :=
|
||||
if let some target := e.find? isCandidate then
|
||||
if e.isIte || e.isDIte then
|
||||
let cond := target.getArg! 1 5
|
||||
-- Try to find a nested `if` in `cond`
|
||||
findSplit? env cond |>.getD target
|
||||
findSplit? env cond exceptionSet |>.getD target
|
||||
else
|
||||
some target
|
||||
else
|
||||
none
|
||||
where
|
||||
isCandidate (e : Expr) : Bool := Id.run <| do
|
||||
if e.isIte || e.isDIte then
|
||||
if exceptionSet.contains e then
|
||||
false
|
||||
else if e.isIte || e.isDIte then
|
||||
!(e.getArg! 1 5).hasLooseBVars
|
||||
else if let some info := isMatcherAppCore? env e then
|
||||
let args := e.getAppArgs
|
||||
|
|
@ -146,15 +148,20 @@ end Split
|
|||
|
||||
open Split
|
||||
|
||||
def splitTarget? (mvarId : MVarId) : MetaM (Option (List MVarId)) := commitWhenSome? do
|
||||
if let some e := findSplit? (← getEnv) (← instantiateMVars (← getMVarType mvarId)) then
|
||||
if e.isIte || e.isDIte then
|
||||
return (← splitIfTarget? mvarId).map fun (s₁, s₂) => [s₁.mvarId, s₂.mvarId]
|
||||
partial def splitTarget? (mvarId : MVarId) : MetaM (Option (List MVarId)) := commitWhenSome? do
|
||||
let rec go (badCases : ExprSet) : MetaM (Option (List MVarId)) := do
|
||||
if let some e := findSplit? (← getEnv) (← instantiateMVars (← getMVarType mvarId)) badCases then
|
||||
if e.isIte || e.isDIte then
|
||||
return (← splitIfTarget? mvarId).map fun (s₁, s₂) => [s₁.mvarId, s₂.mvarId]
|
||||
else
|
||||
try
|
||||
splitMatch mvarId e
|
||||
catch _ =>
|
||||
go (badCases.insert e)
|
||||
else
|
||||
splitMatch mvarId e
|
||||
else
|
||||
trace[Meta.Tactic.split] "did not find term to split\n{MessageData.ofGoal mvarId}"
|
||||
return none
|
||||
trace[Meta.Tactic.split] "did not find term to split\n{MessageData.ofGoal mvarId}"
|
||||
return none
|
||||
go {}
|
||||
|
||||
def splitLocalDecl? (mvarId : MVarId) (fvarId : FVarId) : MetaM (Option (List MVarId)) := commitWhenSome? do
|
||||
withMVarContext mvarId do
|
||||
|
|
|
|||
176
tests/lean/run/heapSort.lean
Normal file
176
tests/lean/run/heapSort.lean
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
/-
|
||||
Copyright (c) 2021 Mario Carneiro. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Mario Carneiro
|
||||
-/
|
||||
|
||||
/-- A max-heap data structure. -/
|
||||
structure BinaryHeap (α) (lt : α → α → Bool) where
|
||||
arr : Array α
|
||||
|
||||
namespace BinaryHeap
|
||||
|
||||
/-- Core operation for binary heaps, expressed directly on arrays.
|
||||
Given an array which is a max-heap, push item `i` down to restore the max-heap property. -/
|
||||
def heapifyDown (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
|
||||
{a' : Array α // a'.size = a.size} :=
|
||||
let left := 2 * i.1 + 1
|
||||
let right := left + 1
|
||||
have left_le : i ≤ left := Nat.le_trans
|
||||
(by rw [Nat.succ_mul, Nat.one_mul]; exact Nat.le_add_left i i)
|
||||
(Nat.le_add_right ..)
|
||||
have right_le : i ≤ right := Nat.le_trans left_le (Nat.le_add_right ..)
|
||||
have i_le : i ≤ i := Nat.le_refl _
|
||||
have j : {j : Fin a.size // i ≤ j} := if h : left < a.size then
|
||||
if lt (a.get i) (a.get ⟨left, h⟩) then ⟨⟨left, h⟩, left_le⟩ else ⟨i, i_le⟩ else ⟨i, i_le⟩
|
||||
have j := if h : right < a.size then
|
||||
if lt (a.get j) (a.get ⟨right, h⟩) then ⟨⟨right, h⟩, right_le⟩ else j else j
|
||||
if h : i.1 = j then ⟨a, rfl⟩ else
|
||||
let a' := a.swap i j
|
||||
let j' := ⟨j, by rw [a.size_swap i j]; exact j.1.2⟩
|
||||
have : a'.size - j < a.size - i := by
|
||||
rw [a.size_swap i j]; sorry
|
||||
let ⟨a₂, h₂⟩ := heapifyDown lt a' j'
|
||||
⟨a₂, h₂.trans (a.size_swap i j)⟩
|
||||
termination_by _ => a.size - i
|
||||
decreasing_by assumption
|
||||
|
||||
@[simp] theorem size_heapifyDown (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
|
||||
(heapifyDown lt a i).1.size = a.size := (heapifyDown lt a i).2
|
||||
|
||||
/-- Core operation for binary heaps, expressed directly on arrays.
|
||||
Construct a heap from an unsorted array, by heapifying all the elements. -/
|
||||
def mkHeap (lt : α → α → Bool) (a : Array α) : {a' : Array α // a'.size = a.size} :=
|
||||
let rec loop : (i : Nat) → (a : Array α) → i ≤ a.size → {a' : Array α // a'.size = a.size}
|
||||
| 0, a, _ => ⟨a, rfl⟩
|
||||
| i+1, a, h =>
|
||||
let h := Nat.lt_of_succ_le h
|
||||
let a' := heapifyDown lt a ⟨i, h⟩
|
||||
let ⟨a₂, h₂⟩ := loop i a' ((heapifyDown ..).2.symm ▸ Nat.le_of_lt h)
|
||||
⟨a₂, h₂.trans a'.2⟩
|
||||
loop (a.size / 2) a sorry
|
||||
|
||||
@[simp] theorem size_mkHeap (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
|
||||
(mkHeap lt a).1.size = a.size := (mkHeap lt a).2
|
||||
|
||||
/-- Core operation for binary heaps, expressed directly on arrays.
|
||||
Given an array which is a max-heap, push item `i` up to restore the max-heap property. -/
|
||||
def heapifyUp (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
|
||||
{a' : Array α // a'.size = a.size} :=
|
||||
if i0 : i.1 = 0 then ⟨a, rfl⟩ else
|
||||
have : (i.1 - 1) / 2 < i := sorry
|
||||
let j := ⟨(i.1 - 1) / 2, Nat.lt_trans this i.2⟩
|
||||
if lt (a.get j) (a.get i) then
|
||||
let a' := a.swap i j
|
||||
let ⟨a₂, h₂⟩ := heapifyUp lt a' ⟨j.1, by rw [a.size_swap i j]; exact j.2⟩
|
||||
⟨a₂, h₂.trans (a.size_swap i j)⟩
|
||||
else ⟨a, rfl⟩
|
||||
termination_by _ => i.1
|
||||
decreasing_by assumption
|
||||
|
||||
@[simp] theorem size_heapifyUp (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
|
||||
(heapifyUp lt a i).1.size = a.size := (heapifyUp lt a i).2
|
||||
|
||||
/-- `O(1)`. Build a new empty heap. -/
|
||||
def empty (lt) : BinaryHeap α lt := ⟨#[]⟩
|
||||
|
||||
instance (lt) : Inhabited (BinaryHeap α lt) := ⟨empty _⟩
|
||||
instance (lt) : EmptyCollection (BinaryHeap α lt) := ⟨empty _⟩
|
||||
|
||||
/-- `O(1)`. Build a one-element heap. -/
|
||||
def singleton (lt) (x : α) : BinaryHeap α lt := ⟨#[x]⟩
|
||||
|
||||
/-- `O(1)`. Get the number of elements in a `BinaryHeap`. -/
|
||||
def size {lt} (self : BinaryHeap α lt) : Nat := self.1.size
|
||||
|
||||
/-- `O(1)`. Get an element in the heap by index. -/
|
||||
def get {lt} (self : BinaryHeap α lt) (i : Fin self.size) : α := self.1.get i
|
||||
|
||||
/-- `O(log n)`. Insert an element into a `BinaryHeap`, preserving the max-heap property. -/
|
||||
def insert {lt} (self : BinaryHeap α lt) (x : α) : BinaryHeap α lt where
|
||||
arr := let n := self.size;
|
||||
heapifyUp lt (self.1.push x) ⟨n, by rw [Array.size_push]; apply Nat.lt_succ_self⟩
|
||||
|
||||
@[simp] theorem size_insert {lt} (self : BinaryHeap α lt) (x : α) :
|
||||
(self.insert x).size = self.size + 1 := by
|
||||
simp [insert, size, size_heapifyUp]
|
||||
|
||||
/-- `O(1)`. Get the maximum element in a `BinaryHeap`. -/
|
||||
def max {lt} (self : BinaryHeap α lt) : Option α := self.1.get? 0
|
||||
|
||||
/-- Auxiliary for `popMax`. -/
|
||||
def popMaxAux {lt} (self : BinaryHeap α lt) : {a' : BinaryHeap α lt // a'.size = self.size - 1} :=
|
||||
match e: self.1.size with
|
||||
| 0 => ⟨self, by simp [size, e]⟩
|
||||
| n+1 =>
|
||||
have h0 := by rw [e]; apply Nat.succ_pos
|
||||
have hn := by rw [e]; apply Nat.lt_succ_self
|
||||
if hn0 : 0 < n then
|
||||
let a := self.1.swap ⟨0, h0⟩ ⟨n, hn⟩ |>.pop
|
||||
⟨⟨heapifyDown lt a ⟨0, sorry⟩⟩,
|
||||
by simp [size]⟩
|
||||
else
|
||||
⟨⟨self.1.pop⟩, by simp [size]⟩
|
||||
|
||||
/-- `O(log n)`. Remove the maximum element from a `BinaryHeap`.
|
||||
Call `max` first to actually retrieve the maximum element. -/
|
||||
def popMax {lt} (self : BinaryHeap α lt) : BinaryHeap α lt := self.popMaxAux
|
||||
|
||||
@[simp] theorem size_popMax {lt} (self : BinaryHeap α lt) :
|
||||
self.popMax.size = self.size - 1 := self.popMaxAux.2
|
||||
|
||||
/-- `O(log n)`. Return and remove the maximum element from a `BinaryHeap`. -/
|
||||
def extractMax {lt} (self : BinaryHeap α lt) : Option α × BinaryHeap α lt :=
|
||||
(self.max, self.popMax)
|
||||
|
||||
theorem size_pos_of_max {lt} {self : BinaryHeap α lt} (e : self.max = some x) : 0 < self.size :=
|
||||
Decidable.of_not_not fun h: ¬ 0 < self.1.size => by simp [BinaryHeap.max, Array.get?, h] at e
|
||||
|
||||
/-- `O(log n)`. Equivalent to `extractMax (self.insert x)`, except that extraction cannot fail. -/
|
||||
def insertExtractMax {lt} (self : BinaryHeap α lt) (x : α) : α × BinaryHeap α lt :=
|
||||
match e: self.max with
|
||||
| none => (x, self)
|
||||
| some m =>
|
||||
if lt x m then
|
||||
let a := self.1.set ⟨0, size_pos_of_max e⟩ x
|
||||
(m, ⟨heapifyDown lt a ⟨0, by simp; exact size_pos_of_max e⟩⟩)
|
||||
else (x, self)
|
||||
|
||||
/-- `O(log n)`. Equivalent to `(self.max, self.popMax.insert x)`. -/
|
||||
def replaceMax {lt} (self : BinaryHeap α lt) (x : α) : Option α × BinaryHeap α lt :=
|
||||
match e: self.max with
|
||||
| none => (none, ⟨self.1.push x⟩)
|
||||
| some m =>
|
||||
let a := self.1.set ⟨0, size_pos_of_max e⟩ x
|
||||
(some m, ⟨heapifyDown lt a ⟨0, by simp; exact size_pos_of_max e⟩⟩)
|
||||
|
||||
/-- `O(log n)`. Replace the value at index `i` by `x`. Assumes that `x ≤ self.get i`. -/
|
||||
def decreaseKey {lt} (self : BinaryHeap α lt) (i : Fin self.size) (x : α) : BinaryHeap α lt where
|
||||
arr := heapifyDown lt (self.1.set i x) ⟨i, by rw [self.1.size_set]; exact i.2⟩
|
||||
|
||||
/-- `O(log n)`. Replace the value at index `i` by `x`. Assumes that `self.get i ≤ x`. -/
|
||||
def increaseKey {lt} (self : BinaryHeap α lt) (i : Fin self.size) (x : α) : BinaryHeap α lt where
|
||||
arr := heapifyUp lt (self.1.set i x) ⟨i, by rw [self.1.size_set]; exact i.2⟩
|
||||
|
||||
end BinaryHeap
|
||||
|
||||
/-- `O(n)`. Convert an unsorted array to a `BinaryHeap`. -/
|
||||
def Array.toBinaryHeap (lt : α → α → Bool) (a : Array α) : BinaryHeap α lt where
|
||||
arr := BinaryHeap.mkHeap lt a
|
||||
|
||||
/-- `O(n log n)`. Sort an array using a `BinaryHeap`. -/
|
||||
@[specialize] def Array.heapSort (a : Array α) (lt : α → α → Bool) : Array α :=
|
||||
let gt y x := lt x y
|
||||
let rec loop (a : BinaryHeap α gt) (out : Array α) : Array α :=
|
||||
match e:a.max with
|
||||
| none => out
|
||||
| some x =>
|
||||
have : a.popMax.size < a.size := by
|
||||
simp; exact Nat.sub_lt (BinaryHeap.size_pos_of_max e) Nat.zero_lt_one
|
||||
loop a.popMax (out.push x)
|
||||
loop (a.toBinaryHeap gt) #[]
|
||||
termination_by _ => a.size
|
||||
decreasing_by assumption
|
||||
|
||||
attribute [simp] Array.heapSort.loop
|
||||
#check @Array.heapSort.loop._eq_1
|
||||
Loading…
Add table
Reference in a new issue